From fd49ff8080b90687108c46f92321ce10ecc743dc Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 20 Oct 2021 07:42:34 +0800 Subject: [PATCH] add nchw atomic, nhwc and nhwc atomic method for backward weight (#30) * add new algorithm from v4r4r2 * program once issue * add split k function * redefine code * add a matrix unmerge * add b matrix unmerge k0 * trans a and b to gridwise gemm * nhwc init * no hacks and vector load * add hacks * modify some parameter * fix tuning parameter for fp32 * fix tuning parameter for fp16 * start change gridwise k split * init ok * remove a b matrix k0mk1 desc in grid * rewrite calculate gridsize * add kbatch to CalculateBottomIndex * remove some unused function * add clear data function before call kernel * out hacks * in hacks * rename device convolution file and function name * modify kBatch value * fix some tuning code * start from v4r4 nhwc * nhwc atomic is able to run * just for fp32 * enable nchw atomic * tweak * tweak * re-arrange gridwise gemm hot loop for wrw * add wrw v4r5 * v4r4r5 fp16 * v4r4r4 fp16 * v4r4r2 fp16 * V4R4R4XDLNHWC fp16 * V4R4R2XDLATOMICNCHW fp16 * adjust for fp16 * input gridsize * change kbatch to gridsize * testing wrw * clean up * k_batch to gridsize * fix bug * wrw v4r4r4 kbatch change to grid size * wrw v4r4r2 kbatch change to grid size * after merge, change gridwise gemm v2r4 * change MakeCBlockClusterAdaptor * other method use new gridwise gemm * clean up * change pad method to make_right_pad_transform * kbatch out from transform function * clean up and fix bug * fix bug * using function type reduce template parameters * using auto replace define function type * clean up Co-authored-by: ltqin Co-authored-by: Chao Liu Co-authored-by: Jing Zhang --- ...into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp | 147 ++++ ...into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp | 147 ++++ ...lution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp | 132 ++++ ...lution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp | 144 ++++ 
.../gridwise_gemm_xdlops_v2r4.hpp | 666 ++++++++++++++++++ ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 5 +- ...mm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp | 258 +++++++ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 21 +- ...mm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp | 290 ++++++++ ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 276 ++++++++ ...mm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp | 458 ++++++++++++ .../include/driver_gemm_xdlops_v2r4.hpp | 209 ++++++ .../src/conv_wrw_driver_offline.cpp | 168 ++++- host/host_tensor/include/device.hpp | 2 +- .../include/host_tensor_generator.hpp | 11 + 15 files changed, 2914 insertions(+), 20 deletions(-) create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp diff --git 
a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e533ad9188 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp @@ -0,0 +1,147 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = 
wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git 
a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..213e1d6135 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,147 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + 
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmm_grid_desc, + 
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: output tensor + const auto out_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp 
b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..f1e1826d16 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto 
ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: output tensor + const auto out_gemmk_gemmn_grid_desc = 
transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..02e61c0ea3 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,144 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: in +// C: wei +// GemmM = K +// GemmN = Y * X * C +// GemmKTotal = N * Ho * Wo +template +__host__ __device__ constexpr auto 
+transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + 
make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..8a9c932f4c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -0,0 +1,666 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t 
shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_b_k0_m_k1_grid_desc, + const void CONSTANT* p_b_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc)); + const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); + const auto c_block_cluster_adaptor = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); +} +#endif + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; 
+ static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; + + return grid_size; + } + + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + 
constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KBatch), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + } + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using 
CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * 
Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, 
+ Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + 
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = 
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + 
n_thread_data_on_grid_idx[I2])}; + + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + } + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index b5ff1db296..8258aa0e66 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -3,6 +3,7 @@ #include "host_tensor.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" #include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" template +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + 
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_c_hi_wi_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_desc.GetLength(I1); + const auto K = 
out_n_k_ho_wo_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_desc.GetLength(I2); + const auto X = wei_k_c_y_x_desc.GetLength(I3); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad + << std::endl; + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB + Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + 
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = + driver_gemm_xdlops_v2r4, + Sequence<0, 2, 1, 3>, + 3, + 
GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1, 3>, + Sequence<0, 2, 1, 3>, + 3, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, + true, + true>; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + 
driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index b8ecfb4be9..ac75c56bf5 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -4,7 +4,8 @@ #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template & in_n_c_hi_wi, - Tensor& wei_k_c_y_x, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, const Tensor& out_n_k_ho_wo, ck::index_t nrepeat) { @@ -35,8 +36,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem 
wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); @@ -47,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -164,9 +165,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk { float ave_time = driver_gemm_xdlops_v2r3< BlockSize, - TInWei, + TIn, TAcc, - TOut, + TWei, InMemoryDataOperationEnum_t::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), @@ -207,8 +208,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk true, // ABlockLdsExtraM true // BBlockLdsExtraN >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), out_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_grid_desc, diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..579c7a1200 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,290 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include 
"transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = 
out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_desc.GetLength(I1); + const auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad + << std::endl; + + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 
0>{}, // 2-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + 
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + 
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..bc5d599604 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,276 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" + +template +void 
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; + +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + 
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate 
after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + 
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..29b404f7d0 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,458 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto 
wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + 
constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_desc.GetLength(I1); + const auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + 
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad + << std::endl; + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 3, 2>, + 2, + 
GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + // timing + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // verification + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + 
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..65c4f62367 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp @@ -0,0 +1,209 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R4 +#define DRIVER_GEMM_XDLOPS_V2R4 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4.hpp" + +template +__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = + GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4; + + { + std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << 
"b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 has invalid setting"); + } + + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + const auto c_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch); + { + std::cout << "gridSize : " << grid_size << std::endl; + } + + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc)); + DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + 
a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc); + b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); +#endif + return ave_time; +} +#endif diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 310dbfe1eb..50f4d6a9b3 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -14,13 +14,25 @@ #include "host_conv_bwd_weight.hpp" #include "device_tensor.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 +#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 +#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 +#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 +#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0 +#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 enum 
ConvBackwardWeightAlgo { - V4R4R2XDLNCHW, + V4R4R2XDLNCHW, // 0 + V4R4R4XDLNHWC, // 1 + V4R4R2XDLATOMICNCHW, // 2 + V4R4R4XDLATOMICNHWC, // 3 + V4R4R5XDLATOMICNHWC, // 4 }; int main(int argc, char* argv[]) @@ -37,10 +49,11 @@ int main(int argc, char* argv[]) #if USE_DYNAMIC_MODE // dynamic mode - if(argc != 22) + if(argc != 23) { printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + printf("additional: desired_grid_size\n"); exit(1); } @@ -68,6 +81,8 @@ int main(int argc, char* argv[]) const index_t in_right_pad_h = std::stoi(argv[20]); const index_t in_right_pad_w = std::stoi(argv[21]); + const index_t desired_grid_size = std::stoi(argv[22]); + const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1; @@ -114,16 +129,19 @@ int main(int argc, char* argv[]) #if 0 using in_data_t = float; + using wei_data_t = float; using acc_data_t = float; using out_data_t = float; #elif 1 using in_data_t = half_t; - using acc_data_t = float; using out_data_t = half_t; + using acc_data_t = float; + using wei_data_t = float; #elif 1 using in_data_t = int8_t; - using acc_data_t = int32_t; using out_data_t = int8_t; + using acc_data_t = int32_t; + using wei_data_t = int8_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); @@ -164,8 +182,8 @@ int main(int argc, char* argv[]) } Tensor in(in_lengths_host); - Tensor wei_device(wei_lengths_host); - Tensor wei_host(wei_lengths_host); + Tensor wei_device(wei_lengths_host); + Tensor wei_host(wei_lengths_host); Tensor out(out_lengths_host); std::cout << "layout: " << layout << std::endl; @@ -231,6 +249,26 @@ int main(int argc, char* argv[]) in_right_pads_dev); }; + auto f_make_for_device_nhwc = [&]() { + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = 
make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + // set zero to wei_device + wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); #if USE_CONV_WRW_V4R4R2_XDL_NCHW if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) { @@ -242,6 +280,7 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( tmp[I0], @@ -258,6 +297,121 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_WRW_V4R4R4_XDL_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + if(do_verification) { host_direct_convolution_backward_weights(out, diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index 9b66f24f7a..cb1a6effa1 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,6 +2,7 @@ #define DEVICE_HPP #include +#include #include #include #include "hip/hip_runtime.h" @@ -80,5 +81,4 @@ float launch_and_time_kernel( return timer.GetElapsedTime() / nrepeat; } - #endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index 7c09843d01..b0d53995ed 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -15,6 +15,17 @@ struct GeneratorTensor_1 } }; +struct GeneratorTensor_0 +{ + int value = 0; + + template + float operator()(Is...) + { + return value; + } +}; + struct GeneratorTensor_2 { int min_value = 0;