mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Merge branch 'develop' into miopen_downstream_all
This commit is contained in:
@@ -21,8 +21,8 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
typename IYTilda,
|
||||
typename IXTilda,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
IYTilda i_ytilda,
|
||||
IXTilda i_xtilda,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(i_ytilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_freeze_transform(i_xtilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = 1
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = C
|
||||
// GemmK = K
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmK = N * Ho * Wo
|
||||
// GemmN = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = C * Y * X;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,129 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmK = N * Ho * Wo
|
||||
// GemmN = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = C * Y * X;
|
||||
const auto GemmK = N * Ho * Wo;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,147 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = Y * X * C;
|
||||
const auto GemmN = K;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: output tensor
|
||||
const auto out_gemmktotal_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,132 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = Y * X * C;
|
||||
const auto GemmN = K;
|
||||
const auto GemmK = N * Ho * Wo;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: output tensor
|
||||
const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,144 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: out
|
||||
// B: in
|
||||
// C: wei
|
||||
// GemmM = K
|
||||
// GemmN = Y * X * C
|
||||
// GemmKTotal = N * Ho * Wo
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = Y * X * C;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
|
||||
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
|
||||
// will be very bad
|
||||
template <typename LowLengths>
|
||||
struct Merge_v3_division_mod
|
||||
{
|
||||
static constexpr index_t NDimLow = LowLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<NDimLow>;
|
||||
using UpperIndex = MultiIndex<1>;
|
||||
|
||||
using LowLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
|
||||
|
||||
LowLengths low_lengths_;
|
||||
LowLengthsScan low_lengths_scan_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr Merge_v3_division_mod() = default;
|
||||
|
||||
__host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
index_t tmp = idx_up[Number<0>{}];
|
||||
|
||||
// division and mod
|
||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||
idx_low(i) = tmp / this->low_lengths_scan_[i];
|
||||
tmp %= this->low_lengths_scan_[i];
|
||||
});
|
||||
|
||||
idx_low(Number<NDimLow - 1>{}) = tmp;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
typename UpIdxDiff,
|
||||
typename LowIdx,
|
||||
typename UpIdx,
|
||||
index_t Hack>
|
||||
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff&,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up_new,
|
||||
Number<Hack>) const
|
||||
{
|
||||
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
|
||||
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto INm1 = Number<NDimLow - 1>{};
|
||||
|
||||
index_t tmp = idx_up_new[I0];
|
||||
|
||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||
const index_t tmp2 = idx_low[i];
|
||||
idx_low(i) = tmp / this->low_lengths_scan_[i];
|
||||
idx_diff_low(i) = idx_low[i] - tmp2;
|
||||
tmp %= this->low_lengths_scan_[i];
|
||||
});
|
||||
|
||||
const index_t tmp2 = idx_low[INm1];
|
||||
idx_low(INm1) = tmp;
|
||||
idx_diff_low(INm1) = idx_low[INm1] - tmp2;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return is_known_at_compile_time<LowLengths>::value &&
|
||||
is_known_at_compile_time<LowLengthsScan>::value &&
|
||||
is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ static constexpr bool
|
||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("Merge_v3_direct_division_mod, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_scan_ ");
|
||||
print_multi_index(low_lengths_scan_);
|
||||
printf("up_lengths_ ");
|
||||
print_multi_index(up_lengths_);
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||
struct UnMerge
|
||||
{
|
||||
|
||||
@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
|
||||
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||
__host__ __device__ constexpr auto make_right_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const RightPadLength& right_pad,
|
||||
@@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
return make_merge_transform_v2_magic_division(low_lengths);
|
||||
#else
|
||||
return make_merge_transform_v1_carry_check(low_lengths);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
|
||||
{
|
||||
return Merge_v1_carry_check<LowLengths>{low_lengths};
|
||||
#else
|
||||
#if 1
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#else
|
||||
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
|
||||
{
|
||||
#if 1
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#else
|
||||
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
|
||||
{
|
||||
return Merge_v3_division_mod<LowLengths>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
||||
|
||||
@@ -189,8 +189,7 @@ struct TensorAdaptor
|
||||
bool is_known = true;
|
||||
|
||||
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
|
||||
is_known &=
|
||||
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
|
||||
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
|
||||
});
|
||||
|
||||
return is_known && is_known_at_compile_time<ElementSize>::value;
|
||||
|
||||
@@ -185,8 +185,7 @@ struct TensorDescriptor
|
||||
bool is_known = true;
|
||||
|
||||
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
|
||||
is_known &=
|
||||
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
|
||||
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
|
||||
});
|
||||
|
||||
return is_known && is_known_at_compile_time<ElementSize>::value &&
|
||||
@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
|
||||
|
||||
template <typename TensorDesc>
|
||||
using TensorCoordinate_t = decltype(make_tensor_coordinate(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
|
||||
|
||||
template <typename TensorDesc>
|
||||
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
||||
const BThreadBuffer& b_thread_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
|
||||
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
|
||||
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
|
||||
@@ -4,21 +4,22 @@
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "xdlops_gemm.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
class ABlockDesc,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
{
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -26,111 +27,169 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
|
||||
|
||||
static constexpr index_t MWaves = M1 / MPerWave;
|
||||
static constexpr index_t NWaves = N1 / NPerWave;
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
static constexpr index_t MRepeat = M0;
|
||||
static constexpr index_t NRepeat = N0;
|
||||
StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
|
||||
c_thread_buf_;
|
||||
|
||||
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static CIndex
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const index_t waveId = get_thread_local_1d_id() / WaveSize;
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
|
||||
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
|
||||
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
|
||||
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
|
||||
b_thread_copy_{CalculateBThreadOriginDataIndex()}
|
||||
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
|
||||
"wrong! K0 dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
|
||||
{
|
||||
constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
|
||||
}
|
||||
|
||||
template <typename CMNGridDesc>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
|
||||
|
||||
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
|
||||
static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
@@ -141,49 +200,43 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
|
||||
|
||||
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
|
||||
|
||||
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
make_tuple(I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
|
||||
static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
|
||||
vector_type<FloatAB, K1> a_thread_vec;
|
||||
vector_type<FloatAB, K1> b_thread_vec;
|
||||
|
||||
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
|
||||
});
|
||||
static_for<0, K1, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
|
||||
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
|
||||
});
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
m0,
|
||||
n0>(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf);
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVector(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -192,332 +245,37 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_k0_m0_m1_m2_k1_block_desc),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, MRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
K1,
|
||||
1>;
|
||||
K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_k0_n0_n1_n2_k1_block_desc),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, NRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
K1,
|
||||
1>;
|
||||
K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
class ABlockDesc,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
{
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t MWaves = M1 / MPerWave;
|
||||
static constexpr index_t NWaves = N1 / NPerWave;
|
||||
|
||||
static constexpr index_t MRepeat = M0;
|
||||
static constexpr index_t NRepeat = N0;
|
||||
|
||||
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
|
||||
|
||||
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
|
||||
|
||||
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static CIndex
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
|
||||
const index_t waveId = get_thread_local_1d_id() / WaveSize;
|
||||
|
||||
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
|
||||
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
|
||||
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
|
||||
b_thread_copy_{CalculateBThreadOriginDataIndex()}
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -18,8 +18,9 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -29,7 +30,7 @@ __global__ void
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
@@ -37,14 +38,14 @@ __global__ void
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
@@ -52,7 +53,7 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -63,7 +64,7 @@ __global__ void
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
@@ -73,21 +74,22 @@ __global__ void
|
||||
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
|
||||
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
|
||||
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
|
||||
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -101,9 +103,9 @@ template <index_t BlockSize,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K0PerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
@@ -131,13 +133,19 @@ template <index_t BlockSize,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsExtraM,
|
||||
bool BBlockLdsExtraN>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
@@ -147,14 +155,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -166,27 +194,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
// TODO: turn on this
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
|
||||
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
return false;
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
@@ -200,34 +246,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
constexpr auto M0 = Number<CLayout.M1()>{};
|
||||
constexpr auto M1 = Number<CLayout.N1()>{};
|
||||
constexpr auto M2 = Number<CLayout.M0()>{};
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||
|
||||
constexpr auto N1 = Number<CLayout.N0()>{};
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
|
||||
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
|
||||
|
||||
return c_m0_m1_m2_n_grid_desc;
|
||||
return has_main_k0_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>;
|
||||
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
@@ -238,31 +316,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
#if 1
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
|
||||
|
||||
const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#elif 1
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
|
||||
make_tuple(Sequence<1, 0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#endif
|
||||
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
@@ -270,7 +357,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
|
||||
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
@@ -289,20 +376,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, MPerBlock, K1>,
|
||||
Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -328,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, NPerBlock, K1>,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -352,59 +459,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// a_mtx[K0PerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[K0PerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||
"wrong!");
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t BlkSize = CLayout.GetBlkSize();
|
||||
constexpr index_t NumBlks = CLayout.GetNumBlks();
|
||||
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
|
||||
|
||||
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
|
||||
|
||||
constexpr auto c_mr_nr_blk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, BlkSize>,
|
||||
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -413,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||
@@ -440,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
index_t k0_block_data_begin = 0;
|
||||
|
||||
do
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
@@ -474,41 +552,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
|
||||
|
||||
constexpr index_t N0 = CLayout.N1();
|
||||
constexpr index_t N1 = CLayout.N0();
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
|
||||
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
|
||||
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
|
||||
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
|
||||
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
|
||||
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<1>{},
|
||||
Number<1>{},
|
||||
Number<M0>{},
|
||||
Number<1>{},
|
||||
Number<M2>{},
|
||||
Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
|
||||
c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
|
||||
constexpr auto blk_off =
|
||||
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
|
||||
|
||||
static_for<0, BlkSize, 1>{}([&](auto j) {
|
||||
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
|
||||
c_thread_buf[Number<blk_off>{}]
|
||||
.template AsType<FloatAcc>()[Number<j>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -521,277 +581,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_grid_desc),
|
||||
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
|
||||
n_thread_data_on_grid / (N1 * NWaves),
|
||||
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
|
||||
n_thread_data_on_grid % (N1 * NWaves) / N1,
|
||||
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
|
||||
m_thread_data_on_grid % (M2 * M1) / M2,
|
||||
m_thread_data_on_grid % M2,
|
||||
n_thread_data_on_grid % N1)}
|
||||
.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
}
|
||||
#else
|
||||
{
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
const auto m_thread_data_on_grid_idx =
|
||||
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_grid));
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
const auto n_thread_data_on_grid_idx =
|
||||
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_grid));
|
||||
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_grid_desc),
|
||||
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
|
||||
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
m_thread_data_on_grid / (M2 * M1),
|
||||
m_thread_data_on_grid % (M2 * M1) / M2,
|
||||
m_thread_data_on_grid % M2,
|
||||
n_thread_data_on_grid)};
|
||||
|
||||
auto init_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
make_multi_index(m_thread_data_on_grid_idx[I0],
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
n_thread_data_on_grid_idx[I1],
|
||||
m_thread_data_on_grid_idx[I2],
|
||||
m_thread_data_on_grid_idx[I3],
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
n_thread_data_on_grid_idx[I2])};
|
||||
|
||||
return c_thread_idx_;
|
||||
};
|
||||
|
||||
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
|
||||
(MRepeat == 1 && NRepeat == 1),
|
||||
"wrong");
|
||||
|
||||
if constexpr(MRepeat == 4 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
nrepeat_plus_copy(make_tuple(I2, I2));
|
||||
nrepeat_plus_copy(make_tuple(I2, I3));
|
||||
mrepeat_plus_copy(make_tuple(I3, I3));
|
||||
nrepeat_minus_copy(make_tuple(I3, I2));
|
||||
nrepeat_minus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
mrepeat_plus_copy(make_tuple(I2, I2));
|
||||
mrepeat_plus_copy(make_tuple(I3, I2));
|
||||
nrepeat_plus_copy(make_tuple(I3, I3));
|
||||
mrepeat_minus_copy(make_tuple(I2, I3));
|
||||
mrepeat_minus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 4 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
mrepeat_plus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
nrepeat_plus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
}
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}; // namespace ck
|
||||
|
||||
|
||||
@@ -0,0 +1,666 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
|
||||
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
|
||||
const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc));
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
|
||||
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsExtraM,
|
||||
bool BBlockLdsExtraN>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
|
||||
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
|
||||
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
|
||||
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
|
||||
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
return false;
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockwiseGemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>;
|
||||
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(KBatch),
|
||||
make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
|
||||
|
||||
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t k_batch_id = block_work_idx[I0];
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto a_b_k0_m_k1_block_desc = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
|
||||
Number<MPerBlock + 1>{} * K1,
|
||||
K1,
|
||||
I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
}
|
||||
}();
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto b_b_k0_n_k1_block_desc = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
|
||||
Number<NPerBlock + 1>{} * K1,
|
||||
K1,
|
||||
I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
}
|
||||
}();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, KPerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
3,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
|
||||
a_b_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<1, KPerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
|
||||
b_b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>{};
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
|
||||
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
|
||||
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
|
||||
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
|
||||
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
|
||||
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
|
||||
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_grid_idx =
|
||||
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_grid));
|
||||
|
||||
const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_grid_idx =
|
||||
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_grid));
|
||||
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
|
||||
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
make_multi_index(m_thread_data_on_grid_idx[I0],
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
n_thread_data_on_grid_idx[I1],
|
||||
m_thread_data_on_grid_idx[I2],
|
||||
m_thread_data_on_grid_idx[I3],
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
n_thread_data_on_grid_idx[I2])};
|
||||
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
}; // namespace ck
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
|
||||
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
|
||||
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
|
||||
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
|
||||
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
|
||||
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
|
||||
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
|
||||
|
||||
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<OriginIdx>>::value,
|
||||
"wrong! OriginIdx need to be known at compile-time");
|
||||
|
||||
// Desc is known at compile-time
|
||||
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
|
||||
constexpr auto desc = remove_cvref_t<Desc>{};
|
||||
|
||||
// OriginIdx is known at compile-time
|
||||
constexpr auto origin_idx = to_multi_index(OriginIdx{});
|
||||
|
||||
@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
// remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
//"wrong! SrcBuffer data type is wrong");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
|
||||
{
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
}
|
||||
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
|
||||
{
|
||||
dst_buf.template AtomicAdd<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
}
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
@@ -392,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
@@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -729,9 +733,9 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -886,9 +890,9 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -1303,24 +1307,21 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<
|
||||
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
|
||||
// SrcDesc and DstDesc are known at compile-time
|
||||
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
|
||||
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
|
||||
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
|
||||
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
|
||||
|
||||
@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
// tensor descriptor for src_vector
|
||||
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
|
||||
@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
// tensor descriptor for dst_vector
|
||||
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
|
||||
@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<
|
||||
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
|
||||
// SrcDesc and DstDesc are known at compile-time
|
||||
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
|
||||
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
|
||||
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
|
||||
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
// atomic add
|
||||
// int
|
||||
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
|
||||
|
||||
// float
|
||||
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
@@ -624,8 +640,130 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<float, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(float),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<float, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(float),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(float),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(float),
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<int32_t, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<int32_t, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(int32_t),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(int32_t),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
@@ -659,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
@@ -687,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
// 1) p_dst_wave must be global memory
|
||||
// 2) p_dst_wave to be a wavewise pointer.
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
@@ -720,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_atomic_add requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x2f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x16f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c);
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
|
||||
@@ -48,7 +48,7 @@ struct Array<TData, 0>
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
|
||||
{
|
||||
using data_type = remove_cv_t<remove_reference_t<X>>;
|
||||
using data_type = remove_cvref_t<X>;
|
||||
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
|
||||
}
|
||||
|
||||
|
||||
@@ -85,13 +85,13 @@
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
|
||||
@@ -43,18 +43,15 @@ struct DynamicBuffer
|
||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
@@ -71,15 +68,14 @@ struct DynamicBuffer
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_return_zero<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
|
||||
return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>,
|
||||
t_per_x>(
|
||||
p_data_, i, is_valid_element, element_space_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_customized_value<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(
|
||||
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
|
||||
t_per_x>(
|
||||
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||
}
|
||||
}
|
||||
@@ -98,18 +94,15 @@ struct DynamicBuffer
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
@@ -119,7 +112,7 @@ struct DynamicBuffer
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
#else
|
||||
if(is_valid_element)
|
||||
@@ -140,70 +133,65 @@ struct DynamicBuffer
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
|
||||
int8_t>::value)
|
||||
if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x2_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
|
||||
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
|
||||
if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int8_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x2_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int16_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x4_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x16_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
@@ -227,6 +215,35 @@ struct DynamicBuffer
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
#else
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomicAdd(&p_data_[i], x);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
|
||||
@@ -114,12 +114,11 @@ struct MagicDivision
|
||||
__host__ __device__ static constexpr uint32_t
|
||||
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
#if 1 // debug
|
||||
// HACK: magic division for int32_t
|
||||
// magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
@@ -127,27 +126,9 @@ struct MagicDivision
|
||||
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
|
||||
uint32_t tmp =
|
||||
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
#else
|
||||
// the inline ASM is producing wrong result
|
||||
__host__ __device__ static int32_t
|
||||
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t r;
|
||||
asm volatile("\n \
|
||||
v_mul_hi_u32 %0, %1, %2 \n \
|
||||
v_add_u32_e32 %0, %1, %0 \n \
|
||||
v_lshrrev_b32_e32 %0, %3, %0 \n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
|
||||
|
||||
return as_type<int32_t>(r);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBufferV2 : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
using VecBaseType = typename T::d1_t;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetVectorSize()
|
||||
{
|
||||
return sizeof(typename T::type) / sizeof(VecBaseType);
|
||||
}
|
||||
|
||||
static constexpr index_t vector_size = GetVectorSize();
|
||||
|
||||
VecBaseType invalid_element_value_ = VecBaseType{0};
|
||||
|
||||
T invalid_vec_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value)
|
||||
: base{},
|
||||
invalid_vec_value_{invalid_element_value},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetVector(Number<I> vec_id)
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetVector(Number<I> vec_id) const
|
||||
{
|
||||
return this->At(vec_id);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetElement(Number<I> i, bool)
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
return this->At(vec_id).template AsType<VecBaseType>()(vec_off);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetElement(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
constexpr auto vec_id = Number<i / vector_size>{};
|
||||
constexpr auto vec_off = Number<i % vector_size>{};
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: VecBaseType{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? this->At(vec_id).template AsType<VecBaseType>()[vec_off]
|
||||
: invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I> i) const
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return GetElement(i, true);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
{
|
||||
|
||||
@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
|
||||
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
|
||||
return container_reduce(
|
||||
Tuple<Ts...>{},
|
||||
[](auto x, bool r) {
|
||||
return is_known_at_compile_time<
|
||||
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
|
||||
r;
|
||||
return is_known_at_compile_time<remove_cvref_t<decltype(x)>>::value & r;
|
||||
},
|
||||
true);
|
||||
}
|
||||
|
||||
@@ -374,13 +374,8 @@ extern "C" __global__ void
|
||||
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
|
||||
CGridBlockCluster_BlockId_To_GM10_GN10{}));
|
||||
|
||||
const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// TODO: how to cast?
|
||||
(const void*)p_desc_tuple
|
||||
#pragma clang diagnostic pop
|
||||
);
|
||||
const auto desc_tuple =
|
||||
*reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
|
||||
|
||||
@@ -13,9 +13,15 @@ include_directories(BEFORE
|
||||
|
||||
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
|
||||
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
|
||||
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
|
||||
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
|
||||
|
||||
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
|
||||
|
||||
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
|
||||
|
||||
13
host/driver_offline/include/debug.hpp
Normal file
13
host/driver_offline/include/debug.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef DEBUG_HPP
|
||||
#define DEBUG_HPP
|
||||
|
||||
namespace debug {
|
||||
namespace debug_driver_gemm_xdlops_v2r3 {
|
||||
|
||||
// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
|
||||
static ck::index_t M01 = 1;
|
||||
static ck::index_t N01 = 1;
|
||||
|
||||
} // namespace debug_driver_gemm_xdlops_v2r3
|
||||
} // namespace debug
|
||||
#endif
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
#include "debug.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
@@ -48,17 +49,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -76,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -84,9 +85,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -105,16 +106,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -133,16 +134,16 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
@@ -159,34 +160,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
@@ -208,40 +181,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
//clang-format on
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
@@ -263,8 +238,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -289,19 +264,23 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -159,25 +159,93 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-:
|
||||
// gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
@@ -195,25 +264,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
@@ -223,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
|
||||
{
|
||||
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
|
||||
{
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
i_ytilda,
|
||||
i_xtilda,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);
|
||||
|
||||
if(GemmK0 != 0)
|
||||
{
|
||||
ave_time += driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations&,
|
||||
const InLeftPads&,
|
||||
const InRightPads&,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
// clang-format off
|
||||
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
// clang-format on
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true, // CAccessOrderMRepeatNRepeat
|
||||
false, // ABlockLdsExtraM
|
||||
false // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename GridSizeType>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_c_hi_wi,
|
||||
Tensor<TWei>& wei_k_c_y_x,
|
||||
const Tensor<TOut>& out_n_k_ho_wo,
|
||||
GridSizeType desired_grid_size,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
|
||||
// using vector load 4, so config's wo*ho must be a multiple of 4
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto N = in_n_c_hi_wi_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_desc.GetLength(I1);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_desc.GetLength(I3);
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = Y * X * C;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
|
||||
const auto GemmK = GemmKTotal / GemmK1;
|
||||
|
||||
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
|
||||
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
|
||||
const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
|
||||
const index_t GemmK0 = BatchLen * GemmKPerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
|
||||
|
||||
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
|
||||
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
|
||||
<< std::endl;
|
||||
const auto descs =
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{},
|
||||
GemmKBatch,
|
||||
GemmKPad);
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemB
|
||||
Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
|
||||
Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
|
||||
|
||||
const auto driver_gemm_xdlops =
|
||||
driver_gemm_xdlops_v2r4<BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum_t::AtomicAdd,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
3,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
3,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false,
|
||||
true,
|
||||
true>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
0);
|
||||
// copy result back to host
|
||||
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_c_hi_wi,
|
||||
Tensor<TWei>& wei_k_c_y_x,
|
||||
const Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
// using vector load 4, so config's wo*ho must be a multiple of 4
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
// using vector load 4, so config's wo*ho must be a multiple of 4
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 1, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
// Host-side driver for backward-weight convolution (input NHWC / weight KYXC /
// output NHWK) implemented as a K-split implicit GEMM using XDLOPS with
// AtomicAdd accumulation into the weight buffer.
//
// The GEMM K dimension (N*Ho*Wo) is split into GemmKBatch partial sums that
// run concurrently and accumulate into C via atomics, which is why the C
// (weight) buffer must hold the intended starting values before each launch.
//
// Parameters:
//   in_n_hi_wi_c_lengths / wei_k_y_x_c_lengths / out_n_ho_wo_k_lengths
//       - dimension lengths used to build packed naive tensor descriptors
//   conv_strides, conv_dilations, in_left_pads, in_right_pads
//       - convolution geometry forwarded to the GEMM transform
//   in_n_hi_wi_c      - host input tensor (read only)
//   wei_k_y_x_c       - host weight tensor; overwritten with the final result
//   out_n_ho_wo_k     - host output-gradient tensor (read only)
//   desired_grid_size - target grid size used to pick the K-split factor
//   nrepeat           - kernel repetition count used by the timing loop
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // Allocate device buffers sized by element space and upload all tensors.
    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

// Compile-time tuning configurations; exactly one branch is active.
#if 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1     = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1     = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto N = in_n_hi_wi_c_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_desc.GetLength(I3);

    const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_desc.GetLength(I2);

    // GEMM problem size: M = filter elements, N = output channels,
    // K = reduction over batch and output spatial positions.
    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GemmK = GemmKTotal / GemmK1;

    // Choose the K-split factor (GemmKBatch) so that the total grid size
    // approaches desired_grid_size, then round K up to a padded multiple.
    const auto GridMN           = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
    const index_t GemmKBatch    = std::max(desired_grid_size / GridMN, 1);
    const index_t BatchLen      = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
    const index_t GemmK0        = BatchLen * GemmKPerBlock;
    const index_t GemmKPad      = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
              << std::endl;

    // Build the GEMM-view descriptors (A = in, B = out, C = wei).
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmN
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmN
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    // Instantiate the driver once; it is invoked both in the timing loop and
    // in the final correctness run below.
    const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
        BlockSize,
        TIn,
        TAcc,
        TWei,
        InMemoryDataOperationEnum_t::AtomicAdd,
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
        decltype(wei_gemmm_gemmn_grid_desc),
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerXDL,
        GemmNPerXDL,
        GemmK1,
        MRepeat,
        NRepeat,
        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
        Sequence<0, 1, 2, 3>,
        Sequence<0, 1, 2, 3>,
        2,
        GemmABlockTransferSrcScalarPerVector_GemmM,
        GemmABlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
        Sequence<0, 1, 2, 3>,
        Sequence<0, 1, 2, 3>,
        2,
        GemmBBlockTransferSrcScalarPerVector_GemmN,
        GemmBBlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
        6,
        GemmCThreadTransferDstScalarPerVector,
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
        decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
        decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
        decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
        false, // CAccessOrderMRepeatNRepeat
        true,
        true>;

    // Timed launches (for benchmarking only; results accumulate via AtomicAdd
    // and are discarded).
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        {
            // 2*N*K*Ho*Wo*C*Y*X flops / 1e9 = GFlop; GFlop per ms = TFlop/s.
            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // Re-upload the host weight data before the final run: the kernel uses
    // AtomicAdd, so the timing loop above has accumulated extra partial sums
    // into the device weight buffer that must be discarded.
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);
    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
|
||||
@@ -0,0 +1,276 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
#include "debug.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_hi_wi_c,
|
||||
Tensor<TWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true,
|
||||
true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename GridSizeType>
|
||||
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TIn>& in_n_hi_wi_c,
|
||||
Tensor<TWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
GridSizeType desired_grid_size,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 64;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 64;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto N = in_n_hi_wi_c_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_desc.GetLength(I2);
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = Y * X * C;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
|
||||
const auto GemmK = GemmKTotal / GemmK1;
|
||||
|
||||
const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
|
||||
const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
|
||||
const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
|
||||
const index_t GemmK0 = BatchLen * GemmKPerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;
|
||||
|
||||
std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
|
||||
<< " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
|
||||
<< std::endl;
|
||||
|
||||
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{},
|
||||
GemmKBatch,
|
||||
GemmKPad);
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto wei_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
|
||||
|
||||
const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
|
||||
BlockSize,
|
||||
TIn,
|
||||
TAcc,
|
||||
TWei,
|
||||
InMemoryDataOperationEnum_t::AtomicAdd,
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(wei_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 3, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true,
|
||||
true>;
|
||||
|
||||
// timing
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// verification
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
0);
|
||||
// copy result back to host
|
||||
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
#if 0
|
||||
float ave_time = launch_kernel_gemm_xdlops_v1
|
||||
#else
|
||||
float ave_time = launch_kernel_gemm_xdlops_v2
|
||||
#endif
|
||||
<BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmKPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
@@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
@@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1>,
|
||||
2,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -1,302 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -49,15 +49,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
@@ -77,16 +77,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -105,16 +105,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
@@ -133,16 +133,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -161,16 +161,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
@@ -188,17 +188,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerXDL = 32;
|
||||
constexpr index_t GemmNPerXDL = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
@@ -249,23 +305,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
@@ -287,8 +343,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmMPerXDL,
|
||||
GemmNPerXDL,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -313,19 +369,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
463
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
Normal file
463
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
Normal file
@@ -0,0 +1,463 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
263
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
Normal file
263
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
463
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
Normal file
463
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
Normal file
@@ -0,0 +1,463 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
263
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
Normal file
263
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
Normal file
@@ -0,0 +1,263 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_k_m_device_buf.ToDevice(a_k_m.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_k_m.mDesc.GetLengths()[0];
|
||||
const auto M = a_k_m.mDesc.GetLengths()[1];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0],
|
||||
a_k_m.mDesc.GetStrides()[1],
|
||||
a_k_m.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector_M,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
463
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Normal file
463
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Normal file
@@ -0,0 +1,463 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 1;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
291
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
Normal file
291
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
Normal file
@@ -0,0 +1,291 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_k_n,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_N = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_k_n.mDesc.GetLengths()[1];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0],
|
||||
b_k_n.mDesc.GetStrides()[1],
|
||||
b_k_n.mDesc.GetStrides()[0]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector_N,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
564
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
Normal file
564
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
Normal file
@@ -0,0 +1,564 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_m_n,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 64;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
#if 1
|
||||
// non-padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#else
|
||||
// padded GEMM
|
||||
const auto a_k0_m_k1_grid_desc_tmp =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_right_pad_transform(M, MRightPad),
|
||||
make_pass_through_transform(K1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc_tmp,
|
||||
make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
#endif
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
7,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false, // CAccessOrderMRepeatNRepeat
|
||||
true, // ABlockLdsExtraM
|
||||
true // BBlockLdsExtraN
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01,
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
}
|
||||
347
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
Normal file
347
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
Normal file
@@ -0,0 +1,347 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename ABType, typename AccType, typename CType>
|
||||
void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
|
||||
const Tensor<ABType>& b_n_k,
|
||||
Tensor<CType>& c_n_m,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_n_k_device_buf.ToDevice(b_n_k.mData.data());
|
||||
c_n_m_device_buf.ToDevice(c_n_m.mData.data());
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 256;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 256;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 128;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t MPerBlock = 64;
|
||||
constexpr index_t NPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 4;
|
||||
|
||||
constexpr index_t MPerXDL = 32;
|
||||
constexpr index_t NPerXDL = 32;
|
||||
constexpr index_t K1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto K = a_m_k.mDesc.GetLengths()[1];
|
||||
const auto M = a_m_k.mDesc.GetLengths()[0];
|
||||
const auto N = b_n_k.mDesc.GetLengths()[0];
|
||||
|
||||
constexpr auto K1Number = Number<K1>{};
|
||||
const auto K0 = K / K1Number;
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
|
||||
make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1],
|
||||
a_m_k.mDesc.GetStrides()[0],
|
||||
a_m_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, N, K1Number),
|
||||
make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1],
|
||||
b_n_k.mDesc.GetStrides()[0],
|
||||
b_n_k.mDesc.GetStrides()[1]));
|
||||
|
||||
const auto c_m_n_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: M
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: M
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0
|
||||
Sequence<0>{}, // 1+: N
|
||||
Sequence<0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0>{}, // 0-: K0
|
||||
Sequence<0>{}, // 1-: N
|
||||
Sequence<0>{})); // 2-: K1
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time =
|
||||
driver_gemm_xdlops_v2r3<BlockSize,
|
||||
ABType,
|
||||
AccType,
|
||||
CType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(c_m_n_grid_desc),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector_K1,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector_K1,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_step_hacks),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m_n_grid_desc,
|
||||
a_k0_m_k1_grid_step_hacks,
|
||||
b_k0_n_k1_grid_step_hacks,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hacks,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
c_n_m_device_buf.FromDevice(c_n_m.mData.data());
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
@@ -17,8 +17,8 @@ template <ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t MPerWave,
|
||||
ck::index_t NPerWave,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t K1,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
@@ -79,8 +83,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat>;
|
||||
CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
@@ -123,68 +129,146 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
|
||||
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
const auto c_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0M1M2NGridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
|
||||
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
|
||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||
|
||||
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
||||
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
||||
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
209
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
Normal file
209
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R4
|
||||
#define DRIVER_GEMM_XDLOPS_V2R4
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r4.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename ABK0MK1GridDesc,
|
||||
typename BBK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t K1,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
|
||||
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
ABK0MK1GridDesc,
|
||||
BBK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
{
|
||||
std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
|
||||
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
const auto c_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
|
||||
{
|
||||
std::cout << "gridSize : " << grid_size << std::endl;
|
||||
}
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<ABK0MK1GridDesc>,
|
||||
remove_reference_t<BBK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
|
||||
DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
|
||||
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
|
||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||
|
||||
a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
|
||||
b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -14,15 +15,16 @@
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
|
||||
|
||||
#define USE_MODE 1
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 0
|
||||
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||
|
||||
enum ConvBackwardDataAlgo
|
||||
{
|
||||
V4R1XDLNHWC,
|
||||
V4R1R2XDLNHWC,
|
||||
V4R1XDLNHWC, // 0
|
||||
V4R1R2XDLNHWC, // 1
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -41,7 +43,7 @@ int main(int argc, char* argv[])
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
@@ -79,7 +81,7 @@ int main(int argc, char* argv[])
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -90,28 +92,28 @@ int main(int argc, char* argv[])
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<192>{};
|
||||
constexpr auto Hi = Number<71>{};
|
||||
constexpr auto Wi = Number<71>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
const index_t conv_stride_h = 2;
|
||||
const index_t conv_stride_w = 2;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 1;
|
||||
const index_t in_left_pad_w = 1;
|
||||
const index_t in_right_pad_h = 1;
|
||||
const index_t in_right_pad_w = 1;
|
||||
constexpr auto conv_stride_h = I2;
|
||||
constexpr auto conv_stride_w = I2;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
@@ -119,9 +121,9 @@ int main(int argc, char* argv[])
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
@@ -280,20 +282,43 @@ int main(int argc, char* argv[])
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 &&
|
||||
in_right_pad_w == 0)
|
||||
{
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -19,13 +20,13 @@
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
@@ -49,11 +50,11 @@ int main(int argc, char* argv[])
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_MODE
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
@@ -91,7 +92,7 @@ int main(int argc, char* argv[])
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -102,38 +103,38 @@ int main(int argc, char* argv[])
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<192>{};
|
||||
constexpr auto Hi = Number<71>{};
|
||||
constexpr auto Wi = Number<71>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
const index_t conv_stride_h = 2;
|
||||
const index_t conv_stride_w = 2;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 1;
|
||||
const index_t in_left_pad_w = 1;
|
||||
const index_t in_right_pad_h = 1;
|
||||
const index_t in_right_pad_w = 1;
|
||||
constexpr auto conv_stride_h = I2;
|
||||
constexpr auto conv_stride_w = I2;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
@@ -228,7 +229,6 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
@@ -236,19 +236,6 @@ int main(int argc, char* argv[])
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
@@ -260,7 +247,6 @@ int main(int argc, char* argv[])
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
@@ -268,19 +254,6 @@ int main(int argc, char* argv[])
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
|
||||
436
host/driver_offline/src/conv_wrw_driver_offline.cpp
Normal file
436
host/driver_offline/src/conv_wrw_driver_offline.cpp
Normal file
@@ -0,0 +1,436 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_weight.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
|
||||
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
|
||||
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
|
||||
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
|
||||
|
||||
enum ConvBackwardWeightAlgo
|
||||
{
|
||||
V4R4R2XDLNCHW, // 0
|
||||
V4R4R4XDLNHWC, // 1
|
||||
V4R4R2XDLATOMICNCHW, // 2
|
||||
V4R4R4XDLATOMICNHWC, // 3
|
||||
V4R4R5XDLATOMICNHWC, // 4
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 23)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
printf("additional: desired_grid_size\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t desired_grid_size = std::stoi(argv[22]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr auto N = Number<128>{};
|
||||
constexpr auto C = Number<128>{};
|
||||
constexpr auto Hi = Number<14>{};
|
||||
constexpr auto Wi = Number<14>{};
|
||||
constexpr auto K = Number<256>{};
|
||||
constexpr auto Y = Number<3>{};
|
||||
constexpr auto X = Number<3>{};
|
||||
|
||||
constexpr auto conv_stride_h = I1;
|
||||
constexpr auto conv_stride_w = I1;
|
||||
constexpr auto conv_dilation_h = I1;
|
||||
constexpr auto conv_dilation_w = I1;
|
||||
constexpr auto in_left_pad_h = I1;
|
||||
constexpr auto in_left_pad_w = I1;
|
||||
constexpr auto in_right_pad_h = I1;
|
||||
constexpr auto in_right_pad_w = I1;
|
||||
|
||||
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
|
||||
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
|
||||
|
||||
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
|
||||
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using wei_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using out_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using wei_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using out_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using wei_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<wei_data_t> wei_device(wei_lengths_host);
|
||||
Tensor<wei_data_t> wei_host(wei_lengths_host);
|
||||
Tensor<out_data_t> out(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_out = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
out.GenerateTensorValue(gen_out, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
// set zero to wei_device
|
||||
wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
|
||||
#if USE_CONV_WRW_V4R4R2_XDL_NCHW
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R4_XDL_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC
|
||||
if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk<
|
||||
in_data_t,
|
||||
wei_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei_device,
|
||||
out,
|
||||
desired_grid_size,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_backward_weights(out,
|
||||
in,
|
||||
wei_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(wei_host, wei_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out: ", out.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
288
host/driver_offline/src/gemm_driver_offline.cpp
Normal file
288
host/driver_offline/src/gemm_driver_offline.cpp
Normal file
@@ -0,0 +1,288 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "debug.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_mn.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_mn.hpp"
|
||||
#include "device_gemm_xdlops_mk_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_mk_nk_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_kn_nm.hpp"
|
||||
#include "device_gemm_xdlops_km_nk_nm.hpp"
|
||||
|
||||
#define USE_GEMM_XDL_MK_KN_MN 1
|
||||
#define USE_GEMM_XDL_MK_NK_MN 1
|
||||
#define USE_GEMM_XDL_KM_KN_MN 1
|
||||
#define USE_GEMM_XDL_KM_NK_MN 1
|
||||
#define USE_GEMM_XDL_MK_KN_NM 0
|
||||
#define USE_GEMM_XDL_MK_NK_NM 0
|
||||
#define USE_GEMM_XDL_KM_KN_NM 0
|
||||
#define USE_GEMM_XDL_KM_NK_NM 0
|
||||
|
||||
enum GemmAlgo
|
||||
{
|
||||
Xdl_MK_KN_MN, // 0
|
||||
Xdl_MK_NK_MN, // 1
|
||||
Xdl_KM_KN_MN, // 2
|
||||
Xdl_KM_NK_MN, // 3
|
||||
Xdl_MK_KN_NM, // 4
|
||||
Xdl_MK_NK_NM, // 5
|
||||
Xdl_KM_KN_NM, // 6
|
||||
Xdl_KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
if(argc != 12)
|
||||
{
|
||||
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: M, N, K\n");
|
||||
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const auto algo = static_cast<GemmAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t M = std::stoi(argv[7]);
|
||||
const index_t N = std::stoi(argv[8]);
|
||||
const index_t K = std::stoi(argv[9]);
|
||||
|
||||
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
|
||||
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
|
||||
|
||||
#if 0
|
||||
using ab_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = float;
|
||||
#elif 1
|
||||
using ab_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = half_t;
|
||||
#elif 1
|
||||
using ab_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using c_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
|
||||
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
|
||||
|
||||
// A
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
a_strides_host[0] = static_cast<std::size_t>(K);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
a_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
a_strides_host[0] = static_cast<std::size_t>(M);
|
||||
a_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
// B
|
||||
if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
b_strides_host[0] = static_cast<std::size_t>(K);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
b_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
b_strides_host[0] = static_cast<std::size_t>(N);
|
||||
b_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
// C
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
|
||||
layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
c_lengths_host[0] = static_cast<std::size_t>(M);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(N);
|
||||
c_strides_host[0] = static_cast<std::size_t>(N);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
c_lengths_host[1] = static_cast<std::size_t>(M);
|
||||
c_strides_host[0] = static_cast<std::size_t>(M);
|
||||
c_strides_host[1] = static_cast<std::size_t>(1);
|
||||
}
|
||||
|
||||
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
|
||||
Tensor<ab_data_t> b(b_lengths_host, b_strides_host);
|
||||
Tensor<c_data_t> c_host(c_lengths_host, c_strides_host);
|
||||
Tensor<c_data_t> c_device(c_lengths_host, c_strides_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: ");
|
||||
ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: ");
|
||||
ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: ");
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
a.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_MN
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_MN)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_MK_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_MK_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_KN_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_KN_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_GEMM_XDL_KM_NK_NM
|
||||
if(algo == GemmAlgo::Xdl_KM_NK_NM)
|
||||
{
|
||||
if(layout != GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm(a, b, c_host, layout);
|
||||
|
||||
check_error(c_host, c_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,9 @@
|
||||
#define DEVICE_HPP
|
||||
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
@@ -74,7 +77,8 @@ float launch_and_time_kernel(
|
||||
|
||||
timer.End();
|
||||
|
||||
// std::this_thread::sleep_for (std::chrono::microseconds(10));
|
||||
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
16
host/host_tensor/include/gemm_common.hpp
Normal file
16
host/host_tensor/include/gemm_common.hpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef GEMM_COMMON_HPP
|
||||
#define GEMM_COMMON_HPP
|
||||
|
||||
enum GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
MK_KN_NM, // 4
|
||||
MK_NK_NM, // 5
|
||||
KM_KN_NM, // 6
|
||||
KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
#endif
|
||||
89
host/host_tensor/include/host_conv_bwd_weight.hpp
Normal file
89
host/host_tensor/include/host_conv_bwd_weight.hpp
Normal file
@@ -0,0 +1,89 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename TOut,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_backward_weights(
|
||||
const Tensor<TOut>& out,
|
||||
const Tensor<TIn>& in,
|
||||
Tensor<TWei>& wei,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(out(n, k, ho, wo));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, c, y, x) = v;
|
||||
};
|
||||
|
||||
auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(out(n, ho, wo, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, y, x, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kcyx,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kyxc,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
159
host/host_tensor/include/host_gemm.hpp
Normal file
159
host/host_tensor/include/host_gemm.hpp
Normal file
@@ -0,0 +1,159 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void host_gemm(const Tensor<AType>& a,
|
||||
const Tensor<BType>& b,
|
||||
Tensor<CType>& c,
|
||||
const GemmMatrixLayout layout)
|
||||
{
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
auto f_mk_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
auto f_km_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
auto f_km_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
auto f_mk_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
auto f_mk_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
auto f_km_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
auto f_km_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_0
|
||||
{
|
||||
int value = 0;
|
||||
|
||||
template <typename... Is>
|
||||
float operator()(Is...)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_2
|
||||
{
|
||||
int min_value = 0;
|
||||
|
||||
@@ -9,8 +9,8 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
int NPerBlock;
|
||||
int KPerBlock;
|
||||
|
||||
int MPerWave;
|
||||
int NPerWave;
|
||||
int MPerXDL;
|
||||
int NPerXDL;
|
||||
int K1;
|
||||
|
||||
int MRepeat;
|
||||
@@ -45,8 +45,8 @@ static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
128, // MPerBlock,
|
||||
128, // NPerBlock,
|
||||
4, // KPerBlock,
|
||||
32, // MPerWave,
|
||||
32, // NPerWave,
|
||||
32, // MPerXDL,
|
||||
32, // NPerXDL,
|
||||
4, // K1,
|
||||
2, // MRepeat,
|
||||
2, // NRepeat,
|
||||
|
||||
14
script/docker-rocm4.3.1.sh
Executable file
14
script/docker-rocm4.3.1.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
WORKSPACE=$1
|
||||
echo "workspace: " $WORKSPACE
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v $WORKSPACE:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
|
||||
#--network host \
|
||||
146
script/run.sh
146
script/run.sh
@@ -4,21 +4,12 @@
|
||||
export ROCR_VISIBLE_DEVICE=0
|
||||
export GPU_DEVICE_ORDINAL=0
|
||||
|
||||
## Boost
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
|
||||
|
||||
## Compiling
|
||||
#export OLC_DEBUG_HIP_VERBOSE=1
|
||||
#export OLC_DEBUG_HIP_DUMP=1
|
||||
#export OLC_DEBUG_SAVE_TEMP_DIR=1
|
||||
|
||||
make -j conv_fwd_driver_offline
|
||||
make -j conv_bwd_driver_offline
|
||||
make -j conv_fwd_driver_online
|
||||
|
||||
#rm -rf /root/_hip_binary_kernels_/
|
||||
#rm -rf /tmp/olCompile*
|
||||
#make -j conv_bwd_driver_offline
|
||||
#make -j conv_wrw_driver_offline
|
||||
#make -j gemm_driver_offline
|
||||
|
||||
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
|
||||
LAYOUT=$1
|
||||
ALGO=$2
|
||||
VERIFY=$3
|
||||
@@ -26,22 +17,121 @@ INIT=$4
|
||||
LOG=$5
|
||||
REPEAT=$6
|
||||
|
||||
################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
|
||||
./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
#M01=$7
|
||||
#N01=$8
|
||||
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
KBATCH=$7
|
||||
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
|
||||
#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
|
||||
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
######### layout algo verify init log repeat M___ N___ K___
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
|
||||
|
||||
# Resnet50
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1
|
||||
|
||||
# 256x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
|
||||
# 128x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
# 128x64x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
|
||||
# 64x128x32 c64
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
|
||||
# 64x64x32 c32
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448
|
||||
|
||||
Reference in New Issue
Block a user