diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
index 9c60e8c3ac..fa78d76965 100644
--- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
+++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
@@ -21,8 +21,8 @@ template <typename... Out,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t IYTildaValue,
-          index_t IXTildaValue,
+          typename IYTilda,
+          typename IXTilda,
           index_t GemmK1Value>
 __host__ __device__ constexpr auto
 transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    Number<IYTildaValue>,
-    Number<IXTildaValue>,
+    IYTilda i_ytilda,
+    IXTilda i_xtilda,
     Number<GemmK1Value>)
 {
     constexpr auto I0 = Number<0>{};
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};
 
-    constexpr auto GemmK1 = Number<GemmK1Value>{};
-    constexpr auto IYTilda = Number<IYTildaValue>{};
-    constexpr auto IXTilda = Number<IXTildaValue>{};
+    constexpr auto GemmK1 = Number<GemmK1Value>{};
 
     const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
     const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
 
     // GemmK is different for each GEMM
-    const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
-    const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
+    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
+    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
 
     const auto K1 = GemmK1;
     const auto K0 = K / K1;
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), + make_freeze_transform(i_ytilda), + make_freeze_transform(i_xtilda), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilda), + make_freeze_transform(i_ytilda), make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(IXTilda), + make_freeze_transform(i_xtilda), make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, @@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc); } +// A: out +// B: wei +// C: in +// Number of GEMMs = 1 +// GemmM = N * Ho * Wo +// GemmN = C +// GemmK = K +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + 
const auto ConvStrideW = conv_strides[I1]; + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + } // namespace ck #endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp 
b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
new file mode 100644
index 0000000000..e533ad9188
--- /dev/null
+++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
@@ -0,0 +1,147 @@
+#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
+#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
+
+#include "common_header.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+
+namespace ck {
+
+// GemmM = K
+// GemmK = N * Ho * Wo
+// GemmN = C * Y * X
+template <typename... Wei,
+          typename... In,
+          typename... Out,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads,
+          index_t GemmK1Value,
+          typename GemmKBatchType,
+          typename GemmKPadType>
+__host__ __device__ constexpr auto
+transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
+    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
+    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
+    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    Number<GemmK1Value>,
+    GemmKBatchType GemmKBatch,
+    GemmKPadType GemmKPad)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto GemmK1 = Number<GemmK1Value>{};
+
+    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
+    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
+    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
+
+    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
+    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
+
+    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
+    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
+
+    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
+    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
+
+    const auto ConvStrideH = conv_strides[I0];
+    const auto ConvStrideW = conv_strides[I1];
+
+    const auto
ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + 
make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp 
b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..949f044b7d --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + 
const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, 
+ make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..213e1d6135 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,147 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto 
+transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: output tensor + const auto out_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..f1e1826d16 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& 
out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + 
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: output tensor + const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + 
out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..02e61c0ea3 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,144 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: in +// C: wei +// GemmM = K +// GemmN = Y * X * C +// GemmKTotal = N * Ho * Wo +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = 
out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + 
in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 42a5a875b7..1a25e99f3b 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ 
b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division } }; +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v3_division_mod +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v3_division_mod() = default; + + __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + template struct UnMerge { diff --git 
a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index 6d4e01888b..9a73799173 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform( return LeftPad{low_length, left_pad}; } -template +template __host__ __device__ constexpr auto make_right_pad_transform( const LowLength& low_length, const RightPadLength& right_pad, @@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng template __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) { -#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION +#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return make_merge_transform_v2_magic_division(low_lengths); +#else + return make_merge_transform_v1_carry_check(low_lengths); +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v1_carry_check(const LowLengths& low_lengths) +{ return Merge_v1_carry_check{low_lengths}; -#else -#if 1 - return Merge_v2_magic_division{low_lengths}; -#else - return Merge_v2r2_magic_division{low_lengths}; -#endif -#endif } template __host__ __device__ constexpr auto make_merge_transform_v2_magic_division(const LowLengths& low_lengths) { +#if 1 return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return Merge_v3_division_mod{low_lengths}; } template diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 3b647e433a..50a8088bba 100644 --- 
a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -189,8 +189,7 @@ struct TensorAdaptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value; diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index a6a57ba63b..8f6a5a3e43 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -185,8 +185,7 @@ struct TensorDescriptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value && @@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& template using TensorCoordinate_t = decltype(make_tensor_coordinate( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); template using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp index 03f889649e..5cc2f2393e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 const BThreadBuffer& b_thread_buf, CThreadBuffer& 
c_thread_buf) const { - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index ee6a0b7427..36c6783204 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -4,21 +4,22 @@ #include "common_header.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "xdlops_gemm.hpp" +#include "tensor_adaptor.hpp" namespace ck { template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { - - using CIndex = MultiIndex<2>; - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -26,111 +27,169 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 static constexpr index_t WaveSize = 64; - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static 
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; + StaticBufferV2, MRepeat * NRepeat, true> + c_thread_buf_; - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } __device__ static auto CalculateAThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); } __device__ static auto CalculateBThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - 
const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); } template - __device__ static CIndex + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { + const auto wave_idx = GetWaveIdx(); - const index_t waveId = get_thread_local_1d_id() / WaveSize; + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - return CIndex{m_offset, n_offset}; + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + 
make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); } - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); + static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + "wrong! K0 dimension not consistent"); - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), "wrong! 
K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); } + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() + { + constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); + } + + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); + } + + __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() + { + return 
transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); + static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); + template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, @@ -141,49 +200,43 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - vector_type a_thread_vec; - - vector_type b_thread_vec; - - static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { // read A - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), + a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, + make_tuple(I0, m0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), a_thread_buf); - // read B - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(I0, n0, I0, I0, I0), + 
b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf); - using mfma_input_type = - typename vector_type::type; + static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; - }); + static_for<0, K1, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); - static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; - }); + using mfma_input_type = + typename vector_type::type; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0)); + + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVector(Number{})); }); }); }); @@ -192,332 +245,37 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 private: // A[K, M] static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); // B[K, N] static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence, + Sequence<0, 1, 2, 3, 4>, + 4, K1, - 1>; + K1>; using BThreadCopy = 
ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence, + Sequence<0, 1, 2, 3, 4>, + 4, K1, - 1>; + K1>; - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline -{ - - using CIndex = MultiIndex<2>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto xdlops_gemm = XdlopsGemm{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; - - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; - - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } - - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const 
index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } - } - - template - __device__ static CIndex - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - - const index_t waveId = get_thread_local_1d_id() / WaveSize; - - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; - - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; - - return CIndex{m_offset, n_offset}; - } - - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! 
K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, 
c_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 207f73072f..86e047c965 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -18,8 +18,9 @@ template + typename CM0N0M1N1M2M3M4N2GridDesc, + typename CBlockClusterAdaptor, + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -29,7 +30,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -37,14 +38,14 @@ __global__ void __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -63,7 +64,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -73,21 +74,22 @@ __global__ void cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); const auto b_k0_n_k1_grid_desc = 
*reinterpret_cast( cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); const auto c_block_cluster_adaptor = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #endif @@ -101,9 +103,9 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsExtraM, + bool BBlockLdsExtraN> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -147,14 +155,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + 
return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -166,27 +194,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ __device__ static constexpr bool CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + const CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) { - // TODO: turn on this static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2))) + return false; - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && - (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0); + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; } __host__ __device__ static constexpr index_t @@ -200,34 +246,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return grid_size; } - __host__ __device__ static constexpr auto - MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - constexpr auto xdlops_gemm = XdlopsGemm{}; + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - constexpr auto CLayout = xdlops_gemm.GetCLayout(); - - constexpr auto M0 = Number{}; - constexpr auto M1 = Number{}; - constexpr auto M2 = Number{}; - - constexpr index_t 
MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - - constexpr auto N1 = Number{}; - - const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); - - return c_m0_m1_m2_n_grid_desc; + return has_main_k0_block_loop; } __host__ __device__ static constexpr auto - MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01) { const auto M = c_m_n_grid_desc.GetLength(I0); const auto N = c_m_n_grid_desc.GetLength(I1); @@ -238,31 +316,40 @@ 
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto M0 = M / M1; const auto N0 = N / N1; -#if 1 + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); -#elif 1 - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))), - make_tuple(Sequence<1, 0>{}), - make_tuple(Sequence<0>{})); -#endif + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); return c_blockid_to_m0_n0_block_cluster_adaptor; } - using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatAB* __restrict__ p_shared_block, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const 
CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor& c_block_cluster_adaptor) { const auto a_grid_buf = make_dynamic_buffer( @@ -270,7 +357,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); @@ -289,20 +376,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // A matrix blockwise copy auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence, ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, @@ -328,7 +435,7 @@ struct 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence, BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -352,59 +459,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); - - constexpr auto c_mr_nr_blk_desc = - 
make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - StaticBuffer, - c_mr_nr_blk_desc.GetElementSpaceSize(), - true> - c_thread_buf; + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -413,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB* p_a_block = p_shared_block; FloatAB* p_b_block = p_shared_block + a_block_space_size; - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; @@ -440,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } // main body - index_t k_block_data_begin = 0; + index_t k0_block_data_begin = 0; - do + if constexpr(HasMainKBlockLoop) { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); - a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - block_sync_lds(); + block_sync_lds(); - b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead( 
+ b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - block_sync_lds(); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += KPerBlock; - } while(k_block_data_begin < (K0 - KPerBlock)); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } // tail { @@ -474,41 +552,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } -#if 0 // output: register to global memory { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - constexpr index_t N0 = CLayout.N1(); - constexpr index_t N1 = CLayout.N0(); + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - constexpr auto c_m0_m1_m2_n_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number<1>{}, - Number<1>{}, - Number{}, - Number<1>{}, - Number{}, - Number<1>{})); - - StaticBuffer - c_blk_buf_; - - static_for<0, 
MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - constexpr auto blk_off = - c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(Number{}) = - c_thread_buf[Number{}] - .template AsType()[Number{}]; - }); - }); - }); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -521,277 +581,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_m0_m1_m2_n_grid_desc, - make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), - n_thread_data_on_grid / (N1 * NWaves), - m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), - n_thread_data_on_grid % (N1 * NWaves) / N1, - m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid % N1)} - .Run(c_m0_m1_m2_n_thread_desc, 
- make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - } -#else - { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); - constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ - c_m0_m1_m2_n_grid_desc, - make_multi_index(0, - 0, - 0, - 0, - m_thread_data_on_grid / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid)}; - auto init_copy = [&](auto c_thread_idx_) { - 
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])}; - return c_thread_idx_; - }; - - auto mrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - }; - - auto nrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - }; - - auto mrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - 
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - }; - - auto nrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - }; - - static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or - (MRepeat == 1 && NRepeat == 1), - "wrong"); - - if constexpr(MRepeat == 4 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - nrepeat_plus_copy(make_tuple(I2, I2)); - nrepeat_plus_copy(make_tuple(I2, I3)); - mrepeat_plus_copy(make_tuple(I3, I3)); - nrepeat_minus_copy(make_tuple(I3, I2)); - nrepeat_minus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - 
mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - mrepeat_plus_copy(make_tuple(I2, I2)); - mrepeat_plus_copy(make_tuple(I3, I2)); - nrepeat_plus_copy(make_tuple(I3, I3)); - mrepeat_minus_copy(make_tuple(I2, I3)); - mrepeat_minus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 4 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - mrepeat_plus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - nrepeat_plus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if 
constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - mrepeat_plus_copy(make_tuple(I1, I0)); - } - else if constexpr(MRepeat == 1 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - nrepeat_plus_copy(make_tuple(I0, I1)); - } - else if constexpr(MRepeat == 1 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - } + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); } -#endif } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..8a9c932f4c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -0,0 +1,666 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, 
+ const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_b_k0_m_k1_grid_desc, + const void CONSTANT* p_b_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc)); + const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); + const auto c_block_cluster_adaptor = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); 
+} +#endif + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const 
CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; + + return grid_size; + } + + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return 
make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KBatch), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); + + return 
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + } + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = 
[&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + 
BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto 
b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + 
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + 
m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])}; + + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + } + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp index a925a5cd68..8b75381026 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! 
AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp index 015ad675fb..f6c15fd85a 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp index 0c7aa978a7..20e9a5b366 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1 static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert(is_known_at_compile_time>>::value, + static_assert(is_known_at_compile_time>::value, "wrong! OriginIdx need to be known at compile-time"); // Desc is known at compile-time - constexpr auto desc = remove_cv_t>{}; + constexpr auto desc = remove_cvref_t{}; // OriginIdx is known at compile-time constexpr auto origin_idx = to_multi_index(OriginIdx{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 27fd91812d..7e3f6b3489 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! SrcSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - // static_assert(is_same>, - // remove_cv_t>>::value, - //"wrong! 
SrcBuffer data type is wrong"); - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto I0 = Number<0>{}; @@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); // copy data from dst_vector into dst_buf - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } constexpr auto move_on_dim = [&]() constexpr { @@ -392,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v2 "wrong! SrcDesc need to known at compile-time"); } - __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } @@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2 static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! DstSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cvref_t{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto I0 = Number<0>{}; @@ -729,9 +733,9 @@ struct ThreadwiseTensorSliceTransfer_v3 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -886,9 +890,9 @@ struct ThreadwiseTensorSliceTransfer_v3 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -1303,24 +1307,21 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index ccac4b7b44..bbdaa5fa2b 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); // tensor descriptor for src_vector constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; @@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); // tensor descriptor for dst_vector constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; @@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index affe096ace..10633f8f32 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -7,21 +7,18 @@ namespace ck { -enum struct mfma_instr +enum struct MfmaInstr { - /// fp32 mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, mfma_f32_32x32x2xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction - /// fp16 mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, mfma_f32_32x32x8f16, // k reduction mfma_f32_16x16x16f16, // k reduction - /// bfp16 mfma_f32_32x32x2bf16, mfma_f32_16x16x2bf16, mfma_f32_4x4x2bf16, @@ -29,317 +26,245 
@@ enum struct mfma_instr mfma_f32_16x16x8bf16, // k reduction }; -template -struct mfma_info; +template +struct mfma_type; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 1; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 
num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static 
constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 1; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); } }; // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t 
num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 1; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr 
index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 8; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static 
constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 16; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr 
index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); } }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 4; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + 
intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); } }; #if 0 template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static 
constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 8; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr 
index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 2; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 2; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 2; + static constexpr bool 
is_k_reduction = false; template }; #endif -template -struct xdlops_info +template +struct MfmaSelector { - static constexpr auto mfma_type = mfma_info{}; + template + static constexpr auto GetMfma(); - static constexpr index_t MPerXdlops = MPerXdlops_; - static constexpr index_t NPerXdlops = NPerXdlops_; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + +#if 0 + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + 
return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } +#endif + + static constexpr auto selected_mfma = mfma_type()>{}; + + __host__ __device__ static constexpr void mfma_check() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + __host__ __device__ constexpr MfmaSelector() { mfma_check(); } static constexpr bool IsABroadcast() { @@ -505,186 +602,33 @@ struct xdlops_info return true; } - static constexpr bool IsKReduction() - { - return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); - } - static constexpr index_t GetKPerXdlops() { - return IsKReduction() ? 
mfma_type.num_input_blks : 1; + return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; } - static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } + static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } }; -template +template struct XdlopsGemm { - template - static constexpr auto GetXdlopsInfo(); - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - -#if 0 - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } 
- - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } -#endif + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumXdlops() { - return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); } __host__ __device__ constexpr XdlopsGemm() @@ -697,104 +641,142 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); - static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, - "m != num_input_blks * num_regs_blk"); - static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || - mfma_type.num_output_blks == 1, - "incorrect num_output_blks"); - static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, - "num_regs_blk incorrect"); + static_assert(KPack 
% mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } - static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + { + const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); + const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); + const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); + const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + + return transform_tensor_descriptor( + c_m0_n0_m1_n1_m2_n2_desc, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); } __device__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_type.wave_size; + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || is_same::value, "base base_type must be float, half, ushort!"); - static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); - - constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); - - static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { - constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); - constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); - - mfma_type.template 
run( - p_a_wave[Number{}], - p_b_wave[Number{}], - p_c_thread); + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); }); } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) { - const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; - const index_t blk_id = laneId / mfma_type.num_threads_blk; - const index_t blk_td = laneId % mfma_type.num_threads_blk; + const auto blk_idx = GetBlkIdx(); - index_t n_offset = blk_i * mfma_type.n + blk_td; - index_t m_offset = xdlops_i * mfma_type.m + blk_id * 
mfma_type.group_size; + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; return CIndex{m_offset, n_offset}; } - static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; - static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; - static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; - static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + static constexpr auto mfma = MfmaSelector{}; - static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); - static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); - static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto GetBlkId(const index_t lane_id) + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { - return lane_id / mfma_type.num_threads_blk; + return make_tuple( + Number{}, I1, Number{}, I1); } - - static constexpr auto GetBlkTd(const index_t lane_id) - { - return lane_id % mfma_type.num_threads_blk; - } - - static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; - - struct CLayout - { - __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } - __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } - __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } - __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } - - __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } - - __device__ static constexpr index_t GetNumBlks() { return 
mfma_type.num_output_blks; } - - __device__ static constexpr index_t GetNumXdlops() - { - return MPerXdlops * NPerXdlops / - (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); - } - }; - - __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } }; } // namespace ck diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index a54607a053..3df53bda44 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); +// atomic add +// int +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// float +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, @@ -624,8 +640,130 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } } +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if 
constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -659,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, } // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -687,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, } // buffer_store requires: -// 1) p_dst_wave must be global memory -// 2) p_dst_wave to be a wavewise pointer. +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, @@ -720,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #endif } +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x7fffffff; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index da74fe1d48..083e47fbf1 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); -template +template struct intrin_mfma_f32_32x32x1f32; -template -struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template 
AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x2f32; -template -struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f32; -template -struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x1f32; -template -struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = 
llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x1f32; -template -struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; -template +template struct intrin_mfma_f32_32x32x4f16; -template -struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - 
reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x8f16; -template -struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x16f16; -template -struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - reg_a, - reg_b, 
- reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f16; -template -struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x4f16; -template -struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template 
AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; @@ -448,7 +340,6 @@ template __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec16_1_t::VecType reg_c); - template <> __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, const ushort2_t* reg_b, diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp index 7271094d39..911cefd057 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -48,7 +48,7 @@ struct Array template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { - using data_type = remove_cv_t>; + using data_type = remove_cvref_t; return Array{{std::forward(x), std::forward(xs)...}}; } diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 521ad24d47..5ee4bb9c64 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -85,13 +85,13 @@ #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #endif -#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #endif // pass tensor descriptor by value or void* -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 // merge transformation use magic number 
division #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 210c493602..886737efac 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -43,18 +43,15 @@ struct DynamicBuffer __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -71,15 +68,14 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_return_zero< - remove_cv_t>, - t_per_x>(p_data_, i, is_valid_element, element_space_size_); + return amd_buffer_load_invalid_element_return_return_zero, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_); } else { - return amd_buffer_load_invalid_element_return_customized_value< - remove_cv_t>, - t_per_x>( + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -98,18 +94,15 @@ struct DynamicBuffer } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -119,7 +112,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_ADDRESSING constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store>, t_per_x>( + amd_buffer_store, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); #else if(is_valid_element) @@ -140,70 +133,65 @@ struct DynamicBuffer // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - if constexpr(is_same>>::type, - int8_t>::value) + if constexpr(is_same>::type, int8_t>::value) { - static_assert( - (is_same>, int8_t>::value && - is_same>, int8_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x2_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x4_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x8_t>::value && - is_same>, int8x8_t>::value) || - (is_same>, int8x16_t>::value && - is_same>, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); - if constexpr(is_same>, int8_t>::value && - is_same>, int8_t>::value) + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x2_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x4_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x8_t>::value && - is_same>, int8x8_t>::value) + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x16_t>::value && - is_same>, int8x16_t>::value) + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix @@ -227,6 +215,35 @@ struct DynamicBuffer } } + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& 
x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ADDRESSING + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); +#else + if(is_valid_element) + { + atomicAdd(&p_data_[i], x); + } +#endif + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp index b7489016e9..612aceea2a 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -114,12 +114,11 @@ struct MagicDivision __host__ __device__ static constexpr uint32_t DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) { - uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; + uint32_t tmp = __umulhi(dividend, multiplier); return (tmp + dividend) >> shift; } -#if 1 // debug - // HACK: magic division for int32_t + // magic division for int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // non-negative for result to be correct // TODO: figure out how to do magic number divison for int32_t as dividended @@ -127,27 +126,9 @@ struct MagicDivision DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { uint32_t dividend_u32 = as_type(dividend_i32); - uint32_t tmp = - (static_cast(dividend_u32) * static_cast(multiplier)) >> 32; + uint32_t tmp = __umulhi(dividend_u32, multiplier); return (tmp + 
dividend_u32) >> shift; } -#else - // the inline ASM is producing wrong result - __host__ __device__ static int32_t - DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) - { - uint32_t r; - asm volatile("\n \ - v_mul_hi_u32 %0, %1, %2 \n \ - v_add_u32_e32 %0, %1, %0 \n \ - v_lshrrev_b32_e32 %0, %3, %0 \n \ - " - : "=v"(r) - : "v"(as_type(dividend_i32)), "s"(multiplier), "s"(shift)); - - return as_type(r); - } -#endif }; } // namespace ck diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp index cd67b8a0be..9615d10c59 100644 --- a/composable_kernel/include/utility/static_buffer.hpp +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } }; +template +struct StaticBufferV2 : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + using VecBaseType = typename T::d1_t; + + __host__ __device__ static constexpr index_t GetVectorSize() + { + return sizeof(typename T::type) / sizeof(VecBaseType); + } + + static constexpr index_t vector_size = GetVectorSize(); + + VecBaseType invalid_element_value_ = VecBaseType{0}; + + T invalid_vec_value_ = T{0}; + + __host__ __device__ constexpr StaticBufferV2() : base{} {} + + __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value) + : base{}, + invalid_vec_value_{invalid_element_value}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + template + __host__ __device__ constexpr auto& GetVector(Number vec_id) + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr const auto& GetVector(Number vec_id) const + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr auto& 
GetElement(Number i, bool) + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + return this->At(vec_id).template AsType()(vec_off); + } + + template + __host__ __device__ constexpr auto GetElement(Number i, bool is_valid_element) const + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : VecBaseType{0}; + } + else + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : invalid_element_value_; + } + } + + template + __host__ __device__ constexpr auto operator[](Number i) const + { + return GetElement(i, true); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return GetElement(i, true); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + template __host__ __device__ constexpr auto make_static_buffer(Number) { diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index ee96a8b435..70f4d77d87 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl __host__ __device__ constexpr auto make_tuple(Xs&&... 
xs) { - return Tuple>...>(std::forward(xs)...); + return Tuple...>(std::forward(xs)...); } } // namespace ck diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp index 9499a3596c..55a79d2594 100644 --- a/composable_kernel/include/utility/tuple_helper.hpp +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -14,9 +14,7 @@ struct is_known_at_compile_time> return container_reduce( Tuple{}, [](auto x, bool r) { - return is_known_at_compile_time< - remove_cv_t>>::value & - r; + return is_known_at_compile_time>::value & r; }, true); } diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index c1208ac3cb..71239e0ecc 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -374,13 +374,8 @@ extern "C" __global__ void CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, CGridBlockCluster_BlockId_To_GM10_GN10{})); - const auto desc_tuple = *reinterpret_cast( -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - // TODO: how to cast? 
- (const void*)p_desc_tuple -#pragma clang diagnostic pop - ); + const auto desc_tuple = + *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index fec11e99af..a3b3613293 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -13,9 +13,15 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) +set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) +set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) +add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) +target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/debug.hpp b/host/driver_offline/include/debug.hpp new file mode 100644 index 0000000000..72fd0763ba --- /dev/null +++ b/host/driver_offline/include/debug.hpp @@ -0,0 +1,13 @@ +#ifndef DEBUG_HPP +#define DEBUG_HPP + +namespace debug { +namespace debug_driver_gemm_xdlops_v2r3 { + +// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping +static ck::index_t M01 = 1; +static ck::index_t N01 = 1; + +} // namespace debug_driver_gemm_xdlops_v2r3 +} // namespace debug +#endif diff --git 
a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7bd82bf6d5..8258aa0e66 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -3,6 +3,7 @@ #include "host_tensor.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" #include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" template ; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #endif @@ -208,40 +181,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + 
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; @@ -263,8 +238,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -289,19 +264,23 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + 
debug::debug_driver_gemm_xdlops_v2r3::N01, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 0ebf8571f4..28d6226f1b 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ 
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -159,25 +159,93 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + 
constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif - const auto descs = - 
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - I0, - I0, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: + // gemmk1 constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -195,25 +264,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + // clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -223,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy + const auto ConvStrideH 
= conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + float ave_time = 0; + + for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) + { + for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) + { + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + i_ytilda, + i_xtilda, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0); + + if(GemmK0 != 0) + { + ave_time += driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + 
GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy #if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, #else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, #endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true // CAccessOrderMRepeatNRepeat - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + 
debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + } + } + } { const auto N = out_n_ho_wo_k_lengths[I0]; diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp new file mode 100644 index 0000000000..d6955ec000 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp @@ -0,0 +1,389 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations&, + const InLeftPads&, + const InRightPads&, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + 
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 
4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr 
index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + 
constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 
1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0>{}, // 1-: gemmm + Sequence<0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0>{})); // 2-: Gemmk1 + + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 
0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + // clang-format on + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy +#if 0 + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, +#else + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, +#endif + 7, + GemmCThreadTransferDstScalarPerVector, + 
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..ce674758ac --- /dev/null +++ 
b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp @@ -0,0 +1,258 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t 
GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_c_hi_wi_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_desc.GetLength(I2); + const auto X = wei_k_c_y_x_desc.GetLength(I3); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << 
GemmKPad + << std::endl; + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB + Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = + driver_gemm_xdlops_v2r4, + Sequence<0, 2, 1, 3>, + 3, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1, 3>, + Sequence<0, 2, 1, 3>, + 3, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, + true, + true>; + + 
for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp 
b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..ac75c56bf5 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,234 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t 
GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 
1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + 
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..579c7a1200 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,290 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp" 
+#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using 
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = 
wei_k_y_x_c_desc.GetLength(I1); + const auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad + << std::endl; + + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + 
constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + 
decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + 
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..bc5d599604 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,276 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& 
wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; + +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using 
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto 
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + 
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..29b404f7d0 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,458 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto 
wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + 
constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_desc.GetLength(I1); + const auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + 
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad + << std::endl; + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 3, 2>, + 2, + 
GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + // timing + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // verification + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + 
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 4a9d01081c..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem 
in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - 
constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr 
index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t 
GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#endif - - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad -#else - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 -#endif - ( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - for(index_t i = 0; i < 5; ++i) - { -#if 0 - float ave_time = launch_kernel_gemm_xdlops_v1 -#else - float ave_time = launch_kernel_gemm_xdlops_v2 -#endif - , - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - 
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 695ffeeb36..d65ecadb4d 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 constexpr index_t BlockSize = 
256; @@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), 
- make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; @@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), @@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( 
out_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 141a326574..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - 
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = 
Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 
0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1>, - 2, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - 
wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 692751bfb3..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,302 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ 
<< std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - 
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto 
wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - 
Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << 
"Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 7067291c8a..1b23aa1a8c 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,15 +49,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; @@ -77,16 +77,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; 
constexpr index_t NRepeat = 2; @@ -105,16 +105,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 4; @@ -133,16 +133,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -161,16 +161,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t 
GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -188,17 +188,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -249,23 +305,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 
0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; @@ -287,8 +343,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -313,19 +369,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), 
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp new file mode 100644 index 0000000000..c44aa7d9a2 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -0,0 +1,463 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + 
c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 
0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 
32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + 
constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + 
constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr 
index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 
0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + 
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp new file mode 100644 index 0000000000..abaaf32113 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 
0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr 
index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = 
b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git 
a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp new file mode 100644 index 0000000000..0a97d361d4 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -0,0 +1,463 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + 
constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + 
constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + 
constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr 
index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t 
MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + 
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 
0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git 
a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp new file mode 100644 index 0000000000..d51caa3847 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + 
constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t 
MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto 
b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + 
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp new file mode 100644 index 0000000000..30ede2517b --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -0,0 +1,463 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + 
+ constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr 
index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t 
KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; 
+ constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + 
constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t 
MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + 
Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat 
+ true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp new file mode 100644 index 0000000000..58ac3880d6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp @@ -0,0 +1,291 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_nm(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t 
NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + 
constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 
32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, 
// 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + 
false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp new file mode 100644 index 0000000000..e99d570413 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -0,0 +1,564 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_mn(const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t 
NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + 
constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr 
index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using 
ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 
= Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using 
ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + +#if 1 + // non-padded GEMM + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + 
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#else + // padded GEMM + 
const auto a_k0_m_k1_grid_desc_tmp = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp, + make_tuple(make_pass_through_transform(K0), + make_right_pad_transform(M, MRightPad), + make_pass_through_transform(K1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc_tmp, + make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#endif + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, // 
ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp new file mode 100644 index 0000000000..a12cf0733a --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp @@ -0,0 +1,347 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_nm(const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 
128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr 
index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL 
= 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + 
constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: 
N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // 
CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index edfce52a19..4ccfbaab0a 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3 -#define DRIVER_GEMM_XDLOPS_V2R3 +#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP +#define DRIVER_GEMM_XDLOPS_V2R3_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -17,8 +17,8 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsAddExtraM, + bool BBlockLdsAddExtraN> __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const FloatAB* p_b_grid, FloatC* p_c_grid, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, AGridStepHacks, BGridStepHacks, CGridStepHacks, @@ -79,8 +83,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, MPerBlock, NPerBlock, KPerBlock, - MPerWave, - NPerWave, + MPerXDL, + NPerXDL, K1, MRepeat, NRepeat, @@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, CGridStepHacks, AGridMoveSliceWindowStepHacks, 
BGridMoveSliceWindowStepHacks, - CAccessOrderMRepeatNRepeat>; + CAccessOrderMRepeatNRepeat, + ABlockLdsAddExtraM, + BBlockLdsAddExtraN>; { std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " @@ -123,68 +129,146 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; } - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + if(!GridwiseGemm::CheckValidity( + a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } - const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + const auto c_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t>; + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); + if(has_main_k0_block_loop) + { + const auto kernel = 
kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - float ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), - 
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } +} #endif return ave_time; } diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..65c4f62367 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp @@ -0,0 +1,209 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R4 +#define DRIVER_GEMM_XDLOPS_V2R4 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4.hpp" + +template +__host__ float 
driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = + GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4; + + { + std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting"); + } + + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + const auto c_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch); + { + std::cout << "gridSize : " << grid_size << std::endl; + } + + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc)); + DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + + a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc); + b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + 
cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); +#endif + return ave_time; +} +#endif diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 67cea94813..366b5dffbc 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -14,15 +15,16 @@ #include "device_tensor.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp" #define USE_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 enum ConvBackwardDataAlgo { - V4R1XDLNHWC, - V4R1R2XDLNHWC, + V4R1XDLNHWC, // 0 + V4R1R2XDLNHWC, // 1 }; int main(int argc, char* argv[]) @@ -41,7 +43,7 @@ int main(int argc, char* argv[]) // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -79,7 +81,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, 
do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -90,28 +92,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif #if 0 @@ -119,9 +121,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using 
out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); @@ -280,20 +282,43 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); + if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 && + in_right_pad_w == 0) + { + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } + else + { +#if 1 + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); +#endif + } } #endif diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 32c33003c5..48eba2b372 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -19,13 +20,13 @@ #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_FWD_V4R4_NCHW 0 +#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define 
USE_CONV_FWD_V4R4R4_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo { @@ -49,11 +50,11 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_MODE +#if USE_DYNAMIC_MODE // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -91,7 +92,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -102,38 +103,38 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const 
index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif -#if 1 +#if 0 using in_data_t = float; using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; @@ -228,7 +229,6 @@ int main(int argc, char* argv[]) } auto f_make_for_device_nchw = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); @@ -236,19 +236,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return 
make_tuple(in_lengths_dev, wei_lengths_dev, @@ -260,7 +247,6 @@ int main(int argc, char* argv[]) }; auto f_make_for_device_nhwc = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -268,19 +254,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp new file mode 100644 index 0000000000..50f4d6a9b3 --- /dev/null +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -0,0 +1,436 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "debug.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_weight.hpp" +#include "device_tensor.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" +#include 
"device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 +#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 +#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0 +#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0 +#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1 + +enum ConvBackwardWeightAlgo +{ + V4R4R2XDLNCHW, // 0 + V4R4R4XDLNHWC, // 1 + V4R4R2XDLATOMICNCHW, // 2 + V4R4R4XDLATOMICNHWC, // 3 + V4R4R5XDLATOMICNHWC, // 4 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 23) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + printf("additional: desired_grid_size\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t 
conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const index_t desired_grid_size = std::stoi(argv[22]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr auto N = Number<128>{}; + constexpr auto C = Number<128>{}; + constexpr auto Hi = Number<14>{}; + constexpr auto Wi = Number<14>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) 
/ conv_stride_w + I1; +#endif + +#if 0 + using in_data_t = float; + using wei_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using out_data_t = half_t; + using acc_data_t = float; + using wei_data_t = float; +#elif 1 + using in_data_t = int8_t; + using out_data_t = int8_t; + using acc_data_t = int32_t; + using wei_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + } + else if(layout == ConvTensorLayout::NHWC) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + } + else + { + std::runtime_error("wrong! 
not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei_device(wei_lengths_host); + Tensor wei_host(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_out = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + out.GenerateTensorValue(gen_out, num_thread); + } + + auto f_make_for_device_nchw = [&]() { + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + auto f_make_for_device_nhwc = [&]() { + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + // set zero to wei_device + wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); +#if USE_CONV_WRW_V4R4R2_XDL_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R4_XDL_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_weights(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(wei_host, wei_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp new file mode 100644 index 0000000000..e60b4905ae --- /dev/null +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -0,0 +1,288 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "debug.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdlops_mk_kn_mn.hpp" +#include "device_gemm_xdlops_mk_nk_mn.hpp" +#include "device_gemm_xdlops_km_kn_mn.hpp" +#include "device_gemm_xdlops_km_nk_mn.hpp" +#include "device_gemm_xdlops_mk_kn_nm.hpp" +#include "device_gemm_xdlops_mk_nk_nm.hpp" +#include "device_gemm_xdlops_km_kn_nm.hpp" +#include "device_gemm_xdlops_km_nk_nm.hpp" + +#define USE_GEMM_XDL_MK_KN_MN 1 +#define USE_GEMM_XDL_MK_NK_MN 1 +#define 
USE_GEMM_XDL_KM_KN_MN 1 +#define USE_GEMM_XDL_KM_NK_MN 1 +#define USE_GEMM_XDL_MK_KN_NM 0 +#define USE_GEMM_XDL_MK_NK_NM 0 +#define USE_GEMM_XDL_KM_KN_NM 0 +#define USE_GEMM_XDL_KM_NK_NM 0 + +enum GemmAlgo +{ + Xdl_MK_KN_MN, // 0 + Xdl_MK_NK_MN, // 1 + Xdl_KM_KN_MN, // 2 + Xdl_KM_NK_MN, // 3 + Xdl_MK_KN_NM, // 4 + Xdl_MK_NK_NM, // 5 + Xdl_KM_KN_NM, // 6 + Xdl_KM_NK_NM, // 7 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + if(argc != 12) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: M, N, K\n"); + printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n"); + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[1])); + const auto algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t M = std::stoi(argv[7]); + const index_t N = std::stoi(argv[8]); + const index_t K = std::stoi(argv[9]); + + debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); + debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); + +#if 0 + using ab_data_t = float; + using acc_data_t = float; + using c_data_t = float; +#elif 1 + using ab_data_t = half_t; + using acc_data_t = float; + using c_data_t = half_t; +#elif 1 + using ab_data_t = int8_t; + using acc_data_t = int32_t; + using c_data_t = int8_t; +#endif + + std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); + std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); + + // A + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN || + layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + } + else 
+ { + a_lengths_host[0] = static_cast(K); + a_lengths_host[1] = static_cast(M); + a_strides_host[0] = static_cast(M); + a_strides_host[1] = static_cast(1); + } + + // B + if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN || + layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM) + { + b_lengths_host[0] = static_cast(N); + b_lengths_host[1] = static_cast(K); + b_strides_host[0] = static_cast(K); + b_strides_host[1] = static_cast(1); + } + else + { + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + } + + // C + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN || + layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN) + { + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else + { + c_lengths_host[0] = static_cast(N); + c_lengths_host[1] = static_cast(M); + c_strides_host[0] = static_cast(M); + c_strides_host[1] = static_cast(1); + } + + Tensor a(a_lengths_host, a_strides_host); + Tensor b(b_lengths_host, b_strides_host); + Tensor c_host(c_lengths_host, c_strides_host); + Tensor c_device(c_lengths_host, c_strides_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); + ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); + ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + 
break; + case 3: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + +#if USE_GEMM_XDL_MK_KN_MN + if(algo == GemmAlgo::Xdl_MK_KN_MN) + { + if(layout != GemmMatrixLayout::MK_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_kn_mn(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_MN + if(algo == GemmAlgo::Xdl_MK_NK_MN) + { + if(layout != GemmMatrixLayout::MK_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_nk_mn(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_MN + if(algo == GemmAlgo::Xdl_KM_KN_MN) + { + if(layout != GemmMatrixLayout::KM_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_kn_mn(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_MN + if(algo == GemmAlgo::Xdl_KM_NK_MN) + { + if(layout != GemmMatrixLayout::KM_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_nk_mn(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_KN_NM + if(algo == GemmAlgo::Xdl_MK_KN_NM) + { + if(layout != GemmMatrixLayout::MK_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_NM + if(algo == GemmAlgo::Xdl_MK_NK_NM) + { + if(layout != GemmMatrixLayout::MK_NK_NM) + { + throw std::runtime_error("wrong! 
layout"); + } + + device_gemm_xdlops_mk_nk_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_NM + if(algo == GemmAlgo::Xdl_KM_KN_NM) + { + if(layout != GemmMatrixLayout::KM_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_NM + if(algo == GemmAlgo::Xdl_KM_NK_NM) + { + if(layout != GemmMatrixLayout::KM_NK_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_nk_nm(a, b, c_device, nrepeat); + } +#endif + + if(do_verification) + { + host_gemm(a, b, c_host, layout); + + check_error(c_host, c_device); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index e2cba94100..cb1a6effa1 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,6 +2,9 @@ #define DEVICE_HPP #include +#include +#include +#include #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -74,7 +77,8 @@ float launch_and_time_kernel( timer.End(); + // std::this_thread::sleep_for (std::chrono::microseconds(10)); + return timer.GetElapsedTime() / nrepeat; } - #endif diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp new file mode 100644 index 0000000000..f6c0d6f930 --- /dev/null +++ b/host/host_tensor/include/gemm_common.hpp @@ -0,0 +1,16 @@ +#ifndef GEMM_COMMON_HPP +#define GEMM_COMMON_HPP + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +#endif diff --git 
a/host/host_tensor/include/host_conv_bwd_weight.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp new file mode 100644 index 0000000000..ed3e8c3042 --- /dev/null +++ b/host/host_tensor/include/host_conv_bwd_weight.hpp @@ -0,0 +1,89 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution_backward_weights( + const Tensor& out, + const Tensor& in, + Tensor& wei, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(out(n, k, ho, wo)); + } + } + } + } + wei(k, c, y, x) = v; + }; + + auto f_kyxc = [&](auto k, auto y, auto x, auto c) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(out(n, ho, wo, k)); + } + } + } + } + wei(k, y, x, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + 
make_ParallelTensorFunctor(f_kcyx, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_kyxc, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp new file mode 100644 index 0000000000..c582a34258 --- /dev/null +++ b/host/host_tensor/include/host_gemm.hpp @@ -0,0 +1,159 @@ +#pragma once +#include "host_tensor.hpp" +#include "gemm_common.hpp" + +template +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + 
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_KN_NM) + { + auto f_mk_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_NM) + { + auto f_mk_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_NM) + { + auto f_km_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_NM) + { + auto f_km_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + 
c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index 7c09843d01..b0d53995ed 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -15,6 +15,17 @@ struct GeneratorTensor_1 } }; +struct GeneratorTensor_0 +{ + int value = 0; + + template + float operator()(Is...) + { + return value; + } +}; + struct GeneratorTensor_2 { int min_value = 0; diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 97ce326346..361f6e4a26 100644 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -9,8 +9,8 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw int NPerBlock; int KPerBlock; - int MPerWave; - int NPerWave; + int MPerXDL; + int NPerXDL; int K1; int MRepeat; @@ -45,8 +45,8 @@ static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw 128, // MPerBlock, 128, // NPerBlock, 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, + 32, // MPerXDL, + 32, // NPerXDL, 4, // K1, 2, // MRepeat, 2, // NRepeat, diff --git a/script/docker-rocm4.3.1.sh b/script/docker-rocm4.3.1.sh new file mode 100755 index 0000000000..48cb675b69 --- /dev/null +++ b/script/docker-rocm4.3.1.sh @@ -0,0 +1,14 @@ +WORKSPACE=$1 +echo "workspace: " $WORKSPACE + +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v $WORKSPACE:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash + +#--network host \ diff --git a/script/run.sh b/script/run.sh index ecb5c85d81..1ff56b2295 100755 --- a/script/run.sh +++ 
b/script/run.sh @@ -4,21 +4,12 @@ export ROCR_VISIBLE_DEVICE=0 export GPU_DEVICE_ORDINAL=0 -## Boost - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH - -## Compiling -#export OLC_DEBUG_HIP_VERBOSE=1 -#export OLC_DEBUG_HIP_DUMP=1 -#export OLC_DEBUG_SAVE_TEMP_DIR=1 - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online - -#rm -rf /root/_hip_binary_kernels_/ -#rm -rf /tmp/olCompile* +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline +#make -j gemm_driver_offline +DRIVER="./host/driver_offline/conv_fwd_driver_offline" LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -26,22 +17,121 @@ INIT=$4 LOG=$5 REPEAT=$6 -################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 - ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#M01=$7 +#N01=$8 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + KBATCH=$7 
-#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + 
+######### layout algo verify init log repeat M___ N___ K___ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 + +# Resnet50 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 
0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + +# 256x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY 
$INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + + +# 128x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 
128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + +# 128x64x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 + +# 64x128x32 c64 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH + +# 64x64x32 c32 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448