mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Clean-up the headers (#713)
* fix headers for gpu instances * remove unused headers --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -1,275 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Number of GEMMs = YTilde * XTilde
|
||||
// GemmM = C
|
||||
// GemmN = N * HTildeSlice * WTildeSlice
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildeValue,
|
||||
index_t IXTildeValue,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildeValue>,
|
||||
Number<IXTildeValue>,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilde = Number<IYTildeValue>{};
|
||||
constexpr auto IXTilde = Number<IXTildeValue>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilde),
|
||||
make_freeze_transform(IXTilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
#if 1
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
// this add padding check
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(IXTilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(C),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,355 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = YTilde * XTilde
|
||||
// GemmM = N * HTildeSlice * WTildeSlice
|
||||
// GemmN = C
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename IYTilde,
|
||||
typename IXTilde,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
IYTilde i_ytilde,
|
||||
IXTilde i_xtilde,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
// this add padding check
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
#if 1
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = 1
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = C
|
||||
// GemmK = K
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,150 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmK = N * Ho * Wo
|
||||
// GemmN = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = C * Y * X;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,132 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmK = N * Ho * Wo
|
||||
// GemmN = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = C * Y * X;
|
||||
const auto GemmK = N * Ho * Wo;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,150 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = Y * X * C;
|
||||
const auto GemmN = K;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: output tensor
|
||||
const auto out_gemmktotal_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,135 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = Y * X * C;
|
||||
const auto GemmN = K;
|
||||
const auto GemmK = N * Ho * Wo;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: output tensor
|
||||
const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,147 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: out
|
||||
// B: in
|
||||
// C: wei
|
||||
// GemmM = K
|
||||
// GemmN = Y * X * C
|
||||
// GemmKTotal = N * Ho * Wo
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value,
|
||||
typename GemmKBatchType,
|
||||
typename GemmKPadType>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>,
|
||||
GemmKBatchType GemmKBatch,
|
||||
GemmKPadType GemmKPad)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = Y * X * C;
|
||||
const auto GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,260 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
|
||||
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
|
||||
InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,179 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
|
||||
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
|
||||
InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,132 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,132 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,134 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = N * Ho * Wo;
|
||||
const auto GemmN = K;
|
||||
const auto GemmK = Y * X * C;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,135 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM0 = 1
|
||||
// GemmM1 = K
|
||||
// GemmN0 = N0
|
||||
// GemmN1 = (N / N0) * Ho * Wo
|
||||
// GemmK0 = (C / C0) * Y * X
|
||||
// GemmK1 = C0
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename N0Type,
|
||||
typename C0Type>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const N0Type& N0,
|
||||
const C0Type& C0)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto N1 = N / N0;
|
||||
const auto C1 = C / C0;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gk0_gm0_gm1_gk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_unmerge_transform(make_tuple(C0, C1)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform(make_tuple(N1, Ho, Wo)),
|
||||
make_pass_through_transform(C0)),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_n_k_howo_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
|
||||
|
||||
const auto out_n0_n1_1_k_howo_grid_desc =
|
||||
transform_tensor_descriptor(out_n_k_howo_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(Ho * Wo)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_1_k_howo_grid_desc,
|
||||
make_tuple(make_pass_through_transform(I1),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
|
||||
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,586 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatDsPointer,
|
||||
typename FloatE,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// input : A[M, K], or A[K, N]
|
||||
// input : B[K, N], or A[N, K]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CDELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmBiasEPermute_Xdl;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static constexpr index_t NumDTensor = 1;
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
|
||||
{
|
||||
index_t M0 = d_e_grid_desc.M0_;
|
||||
index_t M1 = d_e_grid_desc.M1_;
|
||||
index_t M2 = d_e_grid_desc.M2_;
|
||||
index_t N0 = d_e_grid_desc.N0_;
|
||||
index_t N1 = d_e_grid_desc.N1_;
|
||||
|
||||
index_t stride_M0 = d_e_grid_desc.stride_M0_;
|
||||
index_t stride_M1 = d_e_grid_desc.stride_M1_;
|
||||
index_t stride_M2 = d_e_grid_desc.stride_M2_;
|
||||
index_t stride_N0 = d_e_grid_desc.stride_N0_;
|
||||
index_t stride_N1 = d_e_grid_desc.stride_N1_;
|
||||
|
||||
const auto e_grid_desc_mraw_nraw = [&]() {
|
||||
const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
|
||||
make_tuple(M0, M1, M2, N0, N1),
|
||||
make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_m0_m1_m2_n0_n1,
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
|
||||
make_merge_transform(make_tuple(N0, N1))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
|
||||
|
||||
using DsGridDesc_M_N = Tuple<EGridDesc_M_N>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
const void* p_d_grid,
|
||||
void* p_e_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
|
||||
if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_)
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_)
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
// D pointer
|
||||
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_[I0]);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_gemm_bias_e_permute<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_d,
|
||||
void* p_e,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_d,
|
||||
p_e,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
d_grid_desc,
|
||||
e_grid_desc,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_d,
|
||||
void* p_e,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_d,
|
||||
p_e,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
d_grid_desc,
|
||||
e_grid_desc,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmBiasEPermute_Xdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,662 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseContraction,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
|
||||
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
|
||||
typename CGridBlockCluster_BlockId_To_GM10_GN10,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_dlops_v1r2(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t GK0PerBlock,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs,
|
||||
typename BM10BN10ThreadClusterBN10Xs,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// GM0 and GN0 need to known at compile-time
|
||||
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
|
||||
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
|
||||
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// lds max alignment
|
||||
// TODO: part of them should be moved into blockwise-gemm
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = GK1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
|
||||
"wrong! GM0 and GN0 need to be known at compile-time");
|
||||
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (
|
||||
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
|
||||
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
|
||||
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
|
||||
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
|
||||
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
|
||||
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
|
||||
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
|
||||
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
|
||||
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
|
||||
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
|
||||
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr index_t GM11 = GM1PerBlockGM11;
|
||||
constexpr index_t GN11 = GN1PerBlockGN11;
|
||||
|
||||
const index_t GM10 = GM1 / GM11;
|
||||
const index_t GN10 = GN1 / GN11;
|
||||
|
||||
const index_t grid_size = GM10 * GN10;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
|
||||
{
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
|
||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
const auto GM10 = GM1 / GM11;
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
|
||||
{
|
||||
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
|
||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11)),
|
||||
make_pass_through_transform(GK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
constexpr auto BM = GM0 * GM11;
|
||||
constexpr auto BN = GN0 * GN11;
|
||||
|
||||
constexpr auto BM1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
|
||||
BM1PerThreadBM11>{};
|
||||
constexpr auto BN1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto BM0 = BM / BM1;
|
||||
constexpr auto BN0 = BN / BN1;
|
||||
|
||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
make_tuple(make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
|
||||
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_merge_transform(make_tuple(GM0, GM11)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_merge_transform(make_tuple(GN0, GN11))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
|
||||
c_gm10_bm_gn10_bn_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_unmerge_transform(make_tuple(BN0, BN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_grid_block_cluster_blockid_to_gm10_gn10;
|
||||
}
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
|
||||
|
||||
// divide block work by [GM10, GN10]
|
||||
const auto c_gm10_gn10_block_cluster_idx =
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
|
||||
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
|
||||
|
||||
// lds max alignment
|
||||
// TODO: part of them should be moved into blockwise-gemm
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = GK1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// A matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
||||
|
||||
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
|
||||
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
|
||||
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, igm10, 0, 0),
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, ign10, 0, 0),
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
|
||||
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
|
||||
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_gk0_bm_gk1),
|
||||
decltype(b_block_desc_gk0_bn_gk1),
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
|
||||
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
|
||||
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t gk0_block_on_grid = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
|
||||
gk0_block_on_grid += 2 * GK0PerBlock;
|
||||
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
|
||||
I1,
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
|
||||
|
||||
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
|
||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
Sequence<1,
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
|
||||
1,
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_multi_index(igm10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
|
||||
ign10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
|
||||
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,608 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r2.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AKM0M1GridDesc,
|
||||
typename BKN0N1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r2(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
|
||||
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
cblockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AKMGridDesc,
|
||||
typename BKNGridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlockM1,
|
||||
index_t NPerBlockN1,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
index_t M11N11ThreadClusterM1100,
|
||||
index_t M11N11ThreadClusterN1100,
|
||||
index_t M11N11ThreadClusterM1101,
|
||||
index_t M11N11ThreadClusterN1101,
|
||||
typename ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N1>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
|
||||
const BKNGridDesc& b_k_n_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K == b_k_n_grid_desc.GetLength(I0)) &&
|
||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
|
||||
{
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
|
||||
const auto M1 = Number<MPerBlockM1>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
|
||||
a_k_m_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return a_k_m0_m1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
|
||||
{
|
||||
const auto K = b_k_n_grid_desc.GetLength(I0);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
|
||||
const auto N1 = Number<NPerBlockN1>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
|
||||
b_k_n_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return b_k_n0_n1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
constexpr auto M11 =
|
||||
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
|
||||
constexpr auto N11 =
|
||||
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
|
||||
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
|
||||
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
|
||||
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto c_m0_n0_block_cluster_idx =
|
||||
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
|
||||
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N1>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1>,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k_m0_m1_grid_desc),
|
||||
decltype(a_k_m0_m1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_k_m0_m1_grid_desc,
|
||||
make_multi_index(0, im0, 0),
|
||||
a_k_m0_m1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1>,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k_n0_n1_grid_desc),
|
||||
decltype(b_k_n0_n1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_k_n0_n1_grid_desc,
|
||||
make_multi_index(0, in0, 0),
|
||||
b_k_n0_n1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
|
||||
AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
|
||||
BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_k_n0_n1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K - 2 * KPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||
I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
|
||||
|
||||
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
|
||||
Sequence<1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
|
||||
1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
make_multi_index(im0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
|
||||
in0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,461 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_GRIDWISE_GEMM_V2_HPP
|
||||
#define CK_GRIDWISE_GEMM_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "blockwise_gemm_dlops_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
index_t KPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalStepHacks,
|
||||
typename BGlobalStepHacks,
|
||||
typename CGlobalStepHacks,
|
||||
typename AGlobalMoveSliceWindowStepHacks,
|
||||
typename BGlobalMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return a_block_space_size * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
// const auto E = a_e_k_global_desc.GetLength(I0);
|
||||
const auto K = a_e_k_global_desc.GetLength(I1);
|
||||
|
||||
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
|
||||
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
// divide block work by [M, N]
|
||||
#if 0
|
||||
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
|
||||
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
|
||||
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#else
|
||||
// Hack: this force result into SGPR
|
||||
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
|
||||
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
|
||||
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#endif
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO:: more elegent way of defining c_thread_mtx
|
||||
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_e_k_block_desc),
|
||||
decltype(b_e_n_ho_wo_block_desc),
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K>{};
|
||||
|
||||
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const auto k_thread_id = c_thread_mtx_index.k;
|
||||
const auto ho_thread_id = c_thread_mtx_index.h;
|
||||
const auto wo_thread_id = c_thread_mtx_index.w;
|
||||
|
||||
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
|
||||
|
||||
const index_t ho_thread_data_on_global =
|
||||
ho_block_data_on_global + ho_thread_id * HoPerThread;
|
||||
const index_t wo_thread_data_on_global =
|
||||
wo_block_data_on_global + wo_thread_id * WoPerThread;
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<E, KPerBlock>,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_e_k_global_desc),
|
||||
decltype(a_e_k_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_e_k_global_desc,
|
||||
make_multi_index(0, k_block_data_on_global),
|
||||
a_e_k_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto b_threadwise_transfer =
|
||||
ThreadwiseTensorSliceTransfer_v2<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_e_n_ho_wo_global_desc),
|
||||
decltype(b_e_n_ho_wo_thread_desc),
|
||||
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
1,
|
||||
true>(
|
||||
b_e_n_ho_wo_global_desc,
|
||||
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_shared_block, a_e_k_desc.GetElementSpaceSize());
|
||||
|
||||
// register allocation for output
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// initialize output thread tensor
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
||||
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
BGlobalMoveSliceWindowStepHacks{};
|
||||
|
||||
// double regsiter buffer for b
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
FloatAB,
|
||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
b_thread_even_buf, b_thread_odd_buf;
|
||||
|
||||
// LDS double buffer: preload data
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t e_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
e_block_data_begin += 2 * EPerBlock;
|
||||
|
||||
} while(e_block_data_begin < E - 2 * EPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + k_thread_id * KPerThread;
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
decltype(c_k_n_ho_wo_global_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>(
|
||||
c_k_n_ho_wo_global_desc,
|
||||
make_multi_index(
|
||||
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
||||
.Run(c_k_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
c_global_buf,
|
||||
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
// pass tensor descriptor by reference
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by their pointers
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
|
||||
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
|
||||
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by void*
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const void* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const void* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const void* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
|
||||
const auto b_e_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
|
||||
const auto c_k_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,886 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "static_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector>
|
||||
struct lambda_scalar_per_access_for_src_and_dst
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if(i == SrcVectorDim && i == DstVectorDim)
|
||||
{
|
||||
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
|
||||
}
|
||||
else if(i == SrcVectorDim)
|
||||
{
|
||||
return SrcScalarPerVector;
|
||||
}
|
||||
else if(i == DstVectorDim)
|
||||
{
|
||||
return DstScalarPerVector;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename Dst0Desc,
|
||||
typename Dst1Desc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
struct ThreadwiseTensorSliceTransfer_v3r3
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
|
||||
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
|
||||
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r3(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Dst0Desc& dst0_desc,
|
||||
const Dst1Desc& dst1_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)),
|
||||
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op)
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc,
|
||||
const Dst0Desc& dst0_desc,
|
||||
const Dst1Desc& dst1_desc,
|
||||
const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx);
|
||||
dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
// copy data from src_buf into src_vector_container
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
|
||||
// apply SrcElementwiseOperation on src_vector_container
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
src_vector_container.template AsType<SrcData>()(i) =
|
||||
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
|
||||
});
|
||||
#else
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
is_same<half_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
|
||||
{
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
|
||||
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
|
||||
|
||||
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
|
||||
// TODO: make this logic generic for all scenario
|
||||
static_assert(SrcVectorDim != DstVectorDim, "wrong");
|
||||
|
||||
constexpr auto src_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector = generate_sequence(
|
||||
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
|
||||
constexpr auto data_idx = access_idx * scalar_per_access;
|
||||
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
// src_thread_scratch_
|
||||
const auto src_vector_refs = generate_tie(
|
||||
[&](auto i) -> const src_vector_t& {
|
||||
// i increment corresponds to movement in DstVectorDim
|
||||
return src_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * dst_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_src_vector>{});
|
||||
|
||||
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
|
||||
auto dst_vector_refs = generate_tie(
|
||||
[&](auto i) -> dst_vector_t& {
|
||||
// i increment corresponds to movement in SrcVectorDim
|
||||
return dst_thread_scratch_.GetVectorTypeReference(
|
||||
data_idx_seq + i * src_scalar_step_in_vector);
|
||||
},
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
// TODO type_convert is not used yet!!!!!
|
||||
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename Dst0Buffer, typename Dst1Buffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const Dst0Desc& dst0_desc,
|
||||
const Dst0Buffer& dst0_buf,
|
||||
const Dst1Desc& dst1_desc,
|
||||
const Dst1Buffer& dst1_buf)
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch();
|
||||
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make forward steps: dst0
|
||||
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
|
||||
// DstScalarPerVector
|
||||
// TODO: fix this
|
||||
const auto dst0_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst0_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make forward steps: dst1
|
||||
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
|
||||
// DstScalarPerVector
|
||||
// TODO: fix this
|
||||
const auto dst1_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst1_desc, forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps: dst0
|
||||
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
|
||||
// DstScalarPerVector
|
||||
// TODO: fix this
|
||||
const auto dst0_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst0_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps: dst1
|
||||
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
|
||||
// DstScalarPerVector
|
||||
// TODO: fix this
|
||||
const auto dst1_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst1_desc, backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths[i] - 1 -
|
||||
ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ into dst_vector_container
|
||||
auto dst_vector_container = dst_vector_type{
|
||||
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
|
||||
|
||||
// apply DstElementwiseOperation on dst_vector_container
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_vector_container.template AsType<DstData>()(i) =
|
||||
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
// TODO: BUG: should start at 1
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Dst0Desc dst0_desc,
|
||||
const Dst1Desc dst1_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step);
|
||||
move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
|
||||
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
SrcData,
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>
|
||||
src_thread_scratch_;
|
||||
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
decltype(dst_thread_scratch_desc_),
|
||||
true>
|
||||
dst_thread_scratch_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,14 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
|
||||
#define CK_AMD_LLVM_INTRINSIC_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,25 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_PRINT_HPP
|
||||
#define CK_PRINT_HPP
|
||||
|
||||
#include "array.hpp"
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "container_helper.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void print_array(const char* s, T a)
|
||||
{
|
||||
constexpr index_t nsize = a.Size();
|
||||
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,136 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename C0DataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct ReferenceGemmBias2D : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
const Tensor<C0DataType>& c0_m_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: a_m_k_{a_m_k},
|
||||
b_k_n_{b_k_n},
|
||||
c0_m_n_{c0_m_n},
|
||||
c_m_n_{c_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_m_k_;
|
||||
const Tensor<BDataType>& b_k_n_;
|
||||
const Tensor<CDataType>& c0_m_n_;
|
||||
Tensor<CDataType>& c_m_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmBias2D::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
AccDataType a = 0;
|
||||
AccDataType b = 0;
|
||||
AccDataType acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
|
||||
acc += a * b;
|
||||
}
|
||||
|
||||
CDataType cast_acc = static_cast<CDataType>(acc);
|
||||
arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
const Tensor<C0DataType>& c0_m_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmBias2D"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,140 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct ReferenceGemmBiasActivation : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
const Tensor<CDataType>& c0_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: a_m_k_{a_m_k},
|
||||
b_k_n_{b_k_n},
|
||||
c_m_n_{c_m_n},
|
||||
c0_n_{c0_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_m_k_;
|
||||
const Tensor<BDataType>& b_k_n_;
|
||||
Tensor<CDataType>& c_m_n_;
|
||||
const Tensor<CDataType>& c0_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmBiasActivation::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a;
|
||||
float v_b;
|
||||
|
||||
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
float v_c;
|
||||
|
||||
arg.c_element_op_(v_c, v_acc, static_cast<float>(arg.c0_n_(n)));
|
||||
|
||||
arg.c_m_n_(m, n) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
const Tensor<CDataType>& c0_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmBiasActivation"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,148 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
const Tensor<CDataType>& c0_n,
|
||||
const Tensor<CDataType>& c1_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: a_m_k_{a_m_k},
|
||||
b_k_n_{b_k_n},
|
||||
c_m_n_{c_m_n},
|
||||
c0_n_{c0_n},
|
||||
c1_m_n_{c1_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_m_k_;
|
||||
const Tensor<BDataType>& b_k_n_;
|
||||
Tensor<CDataType>& c_m_n_;
|
||||
const Tensor<CDataType>& c0_n_;
|
||||
const Tensor<CDataType>& c1_m_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmBiasActivationAdd::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a;
|
||||
float v_b;
|
||||
|
||||
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
|
||||
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
float v_c;
|
||||
|
||||
arg.c_element_op_(v_c,
|
||||
v_acc,
|
||||
static_cast<float>(arg.c0_n_(n)),
|
||||
static_cast<float>(arg.c1_m_n_(m, n)));
|
||||
|
||||
arg.c_m_n_(m, n) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
const Tensor<CDataType>& c0_n,
|
||||
const Tensor<CDataType>& c1_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmBiasActivationAdd"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "conv_common.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&)
|
||||
{
|
||||
constexpr auto I0 = ck::Number<0>{};
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
float v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += ck::type_convert<float>(in(n, c, hi, wi)) *
|
||||
ck::type_convert<float>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = ck::type_convert<TOut>(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
const auto Di = in.mDesc.GetLengths()[1];
|
||||
const auto Hi = in.mDesc.GetLengths()[2];
|
||||
const auto Wi = in.mDesc.GetLengths()[3];
|
||||
const auto Z = wei.mDesc.GetLengths()[1];
|
||||
const auto Y = wei.mDesc.GetLengths()[2];
|
||||
const auto X = wei.mDesc.GetLengths()[3];
|
||||
const auto C = wei.mDesc.GetLengths()[4];
|
||||
|
||||
auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) {
|
||||
// do__ must be converted to signed integer, otherwise zmin might be wrong in cases
|
||||
// negative values.
|
||||
const int do_ = static_cast<int>(do_tmp);
|
||||
const int ho = static_cast<int>(ho_tmp);
|
||||
const int wo = static_cast<int>(wo_tmp);
|
||||
const int zmin =
|
||||
std::max(0,
|
||||
(in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) /
|
||||
conv_dilations[I0]);
|
||||
const int ymin =
|
||||
std::max(0,
|
||||
(in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) /
|
||||
conv_dilations[I1]);
|
||||
const int xmin =
|
||||
std::max(0,
|
||||
(in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) /
|
||||
conv_dilations[I2]);
|
||||
const int zmax =
|
||||
std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]);
|
||||
const int ymax =
|
||||
std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]);
|
||||
const int xmax =
|
||||
std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]);
|
||||
const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0];
|
||||
const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1];
|
||||
const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
const TIn* in_n = in.mData.data() + n * Di * Hi * Wi * C;
|
||||
const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C;
|
||||
|
||||
int di = di_min;
|
||||
for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0])
|
||||
{
|
||||
const TIn* in_n_di = in_n + di * Hi * Wi * C;
|
||||
const TWei* wei_k_z = wei_k + z * Y * X * C;
|
||||
int hi = hi_min;
|
||||
|
||||
for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1])
|
||||
{
|
||||
const TIn* in_n_di_hi = in_n_di + hi * Wi * C;
|
||||
const TWei* wei_k_z_y = wei_k_z + y * X * C;
|
||||
int wi = wi_min;
|
||||
|
||||
for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2])
|
||||
{
|
||||
const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C;
|
||||
const TWei* wei_k_z_y_x = wei_k_z_y + x * C;
|
||||
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
v += static_cast<const double>(in_n_di_hi_wi[c]) *
|
||||
static_cast<const double>(wei_k_z_y_x[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out(n, do_, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ndhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3],
|
||||
out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4);
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/utility/functional2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
struct ProfileBestConfig
|
||||
{
|
||||
std::string best_op_name;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_tflops = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = std::numeric_limits<float>::max();
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief This class describes an operation instance(s).
|
||||
*
|
||||
* Op instance defines a particular specializations of operator
|
||||
* template. Thanks to this specific input/output data types, data
|
||||
* layouts and modifying elementwise operations it is able to create
|
||||
* it's input/output tensors, provide pointers to instances which
|
||||
* can execute it and all operation specific parameters.
|
||||
*/
|
||||
template <typename OutDataType, typename... InArgTypes>
|
||||
class OpInstance
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
using TensorPtr = std::unique_ptr<Tensor<T>>;
|
||||
using InTensorsTuple = std::tuple<TensorPtr<InArgTypes>...>;
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
using DeviceBuffers = std::vector<DeviceMemPtr>;
|
||||
|
||||
OpInstance() = default;
|
||||
OpInstance(const OpInstance&) = default;
|
||||
OpInstance& operator=(const OpInstance&) = default;
|
||||
virtual ~OpInstance(){};
|
||||
|
||||
virtual InTensorsTuple GetInputTensors() const = 0;
|
||||
virtual TensorPtr<OutDataType> GetOutputTensor() const = 0;
|
||||
virtual std::unique_ptr<tensor_operation::device::BaseInvoker>
|
||||
MakeInvokerPointer(tensor_operation::device::BaseOperator*) const = 0;
|
||||
virtual std::unique_ptr<tensor_operation::device::BaseArgument>
|
||||
MakeArgumentPointer(tensor_operation::device::BaseOperator*,
|
||||
const DeviceBuffers&,
|
||||
const DeviceMemPtr&) const = 0;
|
||||
virtual std::size_t GetFlops() const = 0;
|
||||
virtual std::size_t GetBtype() const = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A generic operation instance run engine.
|
||||
*/
|
||||
template <typename OutDataType, typename... InArgTypes>
|
||||
class OpInstanceRunEngine
|
||||
{
|
||||
public:
|
||||
using OpInstanceT = OpInstance<InArgTypes..., OutDataType>;
|
||||
template <typename T>
|
||||
using TensorPtr = std::unique_ptr<Tensor<T>>;
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
using InTensorsTuple = std::tuple<TensorPtr<InArgTypes>...>;
|
||||
using DeviceBuffers = std::vector<DeviceMemPtr>;
|
||||
using InArgsTypesTuple = std::tuple<InArgTypes...>;
|
||||
|
||||
OpInstanceRunEngine() = delete;
|
||||
|
||||
template <typename ReferenceOp = std::function<void()>>
|
||||
OpInstanceRunEngine(const OpInstanceT& op_instance,
|
||||
const ReferenceOp& reference_op = ReferenceOp{},
|
||||
bool do_verification = true)
|
||||
: op_instance_{op_instance}
|
||||
{
|
||||
in_tensors_ = op_instance_.GetInputTensors();
|
||||
out_tensor_ = op_instance_.GetOutputTensor();
|
||||
|
||||
if constexpr(std::is_invocable_v<ReferenceOp,
|
||||
const Tensor<InArgTypes>&...,
|
||||
Tensor<OutDataType>&>)
|
||||
{
|
||||
if(do_verification)
|
||||
{
|
||||
ref_output_ = op_instance_.GetOutputTensor();
|
||||
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
|
||||
}
|
||||
}
|
||||
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
|
||||
out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
|
||||
out_tensor_->mDesc.GetElementSpaceSize());
|
||||
out_device_buffer_->SetZero();
|
||||
}
|
||||
|
||||
virtual ~OpInstanceRunEngine(){};
|
||||
|
||||
template <typename OpInstancePtr>
|
||||
bool Test(const std::vector<OpInstancePtr>& op_ptrs)
|
||||
{
|
||||
bool res{true};
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get());
|
||||
auto argument = op_instance_.MakeArgumentPointer(
|
||||
op_ptr.get(), in_device_buffers_, out_device_buffer_);
|
||||
if(op_ptr->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl;
|
||||
invoker->Run(argument.get());
|
||||
out_device_buffer_->FromDevice(out_tensor_->mData.data());
|
||||
if(!ref_output_)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"OpInstanceRunEngine::Test: Reference value not availabe."
|
||||
" You have to provide reference function.");
|
||||
}
|
||||
// TODO: enable flexible use of custom check_error functions
|
||||
bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData);
|
||||
std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
res = res && inst_res;
|
||||
out_device_buffer_->SetZero();
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Given conv problem is not supported by instance: \n\t>>>>"
|
||||
<< op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename OpInstancePtr>
|
||||
ProfileBestConfig Profile(const std::vector<OpInstancePtr>& op_ptrs,
|
||||
bool time_kernel = false,
|
||||
bool do_verification = false,
|
||||
bool do_log = false)
|
||||
{
|
||||
ProfileBestConfig best_config;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto invoker = op_instance_.MakeInvokerPointer(op_ptr.get());
|
||||
auto argument = op_instance_.MakeArgumentPointer(
|
||||
op_ptr.get(), in_device_buffers_, out_device_buffer_);
|
||||
if(op_ptr->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flops = op_instance_.GetFlops();
|
||||
std::size_t num_btype = op_instance_.GetBtype();
|
||||
float tflops = static_cast<float>(flops) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(avg_time < best_config.best_avg_time)
|
||||
{
|
||||
best_config.best_op_name = op_name;
|
||||
best_config.best_tflops = tflops;
|
||||
best_config.best_gb_per_sec = gb_per_sec;
|
||||
best_config.best_avg_time = avg_time;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
out_device_buffer_->FromDevice(out_tensor_->mData.data());
|
||||
if(!ref_output_)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"OpInstanceRunEngine::Profile: Reference value not availabe."
|
||||
" You have to provide reference function.");
|
||||
}
|
||||
// TODO: enable flexible use of custom check_error functions
|
||||
CheckErr(out_tensor_->mData, ref_output_->mData);
|
||||
|
||||
if(do_log) {}
|
||||
}
|
||||
out_device_buffer_->SetZero();
|
||||
}
|
||||
}
|
||||
return best_config;
|
||||
}
|
||||
|
||||
void SetAtol(double a) { atol_ = a; }
|
||||
void SetRtol(double r) { rtol_ = r; }
|
||||
|
||||
private:
|
||||
template <typename F, std::size_t... Is>
|
||||
void CallRefOpUnpackArgs(const F& f, std::index_sequence<Is...>) const
|
||||
{
|
||||
f(*std::get<Is>(in_tensors_)..., *ref_output_);
|
||||
}
|
||||
|
||||
template <std::size_t... Is>
|
||||
void AllocateDeviceInputTensors(std::index_sequence<Is...>)
|
||||
{
|
||||
(AllocateDeviceInputTensorsImpl<Is>(), ...);
|
||||
}
|
||||
|
||||
template <std::size_t Index>
|
||||
void AllocateDeviceInputTensorsImpl()
|
||||
{
|
||||
const auto& ts = std::get<Index>(in_tensors_);
|
||||
in_device_buffers_
|
||||
.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
|
||||
ts->mDesc.GetElementSpaceSize()))
|
||||
->ToDevice(ts->mData.data());
|
||||
}
|
||||
|
||||
static constexpr std::size_t kNInArgs_ = std::tuple_size_v<InTensorsTuple>;
|
||||
const OpInstanceT& op_instance_;
|
||||
double rtol_{1e-5};
|
||||
double atol_{1e-8};
|
||||
|
||||
InTensorsTuple in_tensors_;
|
||||
TensorPtr<OutDataType> out_tensor_;
|
||||
TensorPtr<OutDataType> ref_output_;
|
||||
|
||||
DeviceBuffers in_device_buffers_;
|
||||
DeviceMemPtr out_device_buffer_;
|
||||
|
||||
template <typename T>
|
||||
bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const
|
||||
{
|
||||
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
@@ -1,77 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "profiler/data_type_enum.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <DataTypeEnum DataTypeEnum>
|
||||
struct get_datatype_from_enum;
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum::Int8>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum::Int32>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum::Half>
|
||||
{
|
||||
using type = half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum::Float>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum::Double>
|
||||
{
|
||||
using type = double;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct get_datatype_enum_from_type;
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<int8_t>
|
||||
{
|
||||
static constexpr DataTypeEnum value = DataTypeEnum::Int8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<int32_t>
|
||||
{
|
||||
static constexpr DataTypeEnum value = DataTypeEnum::Int32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<half_t>
|
||||
{
|
||||
static constexpr DataTypeEnum value = DataTypeEnum::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<float>
|
||||
{
|
||||
static constexpr DataTypeEnum value = DataTypeEnum::Float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<double>
|
||||
{
|
||||
static constexpr DataTypeEnum value = DataTypeEnum::Double;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -1,486 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/conv_util.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using INT8 = int8_t;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvBwdDataNoOpPtr =
|
||||
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<DeviceConvBwdDataNoOpPtr>&);
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
|
||||
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename WeiLayout>
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename OutLayout>
|
||||
HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
InDataType, WeiDataType, OutDataType, std::vector<DeviceConvBwdDataNoOpPtr>&, int)
|
||||
{
|
||||
std::cout << "can not find device conv bwd data" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
F32, F32, F32, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
F16, F16, F16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
BF16, BF16, BF16, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
template <>
|
||||
void get_device_conv_bwd_data_op_ptr(
|
||||
INT8, INT8, INT8, std::vector<DeviceConvBwdDataNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(std::size_t i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
template <typename DataType>
|
||||
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
|
||||
{
|
||||
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
||||
template <int NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool profile_convnd_bwd_data_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths,
|
||||
const std::vector<ck::index_t>& conv_filter_strides,
|
||||
const std::vector<ck::index_t>& conv_filter_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(C)};
|
||||
input_dims.insert(
|
||||
std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(K), static_cast<std::size_t>(C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input_host_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<InDataType> input_device_result(
|
||||
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<WeiDataType> weights(
|
||||
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
|
||||
Tensor<OutDataType> output(
|
||||
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
|
||||
|
||||
std::cout << "input: " << input_host_result.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
|
||||
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto RunReference = [&](auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(input_host_result,
|
||||
weights,
|
||||
output,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NDimSpatial>();
|
||||
RunReference(ref_conv);
|
||||
}
|
||||
|
||||
// add device Conv instances
|
||||
std::vector<DeviceConvBwdDataNoOpPtr> conv_ptrs;
|
||||
get_device_conv_bwd_data_op_ptr(
|
||||
InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);
|
||||
|
||||
if(conv_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device Conv instance found");
|
||||
}
|
||||
|
||||
std::string best_conv_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device Conv instances
|
||||
bool success = true;
|
||||
for(auto& conv_ptr : conv_ptrs)
|
||||
{
|
||||
auto argument_ptr = conv_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
|
||||
|
||||
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string conv_name = conv_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_conv_name = conv_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
in_device_buf.FromDevice(input_device_result.mData.data());
|
||||
|
||||
if(!check_out(input_host_result, input_device_result))
|
||||
{
|
||||
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
|
||||
success = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
success = ck::utils::check_err(input_host_result, input_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(output);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(weights);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_host : ";
|
||||
show_data_nhwc_layout(input_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_device: ";
|
||||
show_data_nhwc_layout(input_device_result);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
|
||||
return success;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -1,474 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/conv_util.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DeviceConvndBwdWeightNoOpPtr =
|
||||
DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr>&);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
using DeviceConvndBwdWeightNoOpPtr =
|
||||
ck::tensor_operation::device::instance::DeviceConvndBwdWeightNoOpPtr;
|
||||
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, InLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WeiLayout>
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, WeiLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutLayout>
|
||||
HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
case 2: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::utils::conv::get_host_tensor_descriptor(dims, OutLayout{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
void get_device_conv_bwd_weight_op_ptr(
|
||||
InDataType, WeiDataType, OutDataType, std::vector<DeviceConvndBwdWeightNoOpPtr>&, int)
|
||||
{
|
||||
std::cout << "can not find device conv bwd weight" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
template <>
|
||||
void get_device_conv_bwd_weight_op_ptr(
|
||||
F32, F32, F32, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void get_device_conv_bwd_weight_op_ptr(
|
||||
F16, F16, F16, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_convnd_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void get_device_conv_bwd_weight_op_ptr(
|
||||
BF16, BF16, BF16, std::vector<DeviceConvndBwdWeightNoOpPtr>& conv_ptrs, int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 1:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 2:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
case 3:
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
void show_data_nhwc_layout(Tensor<DataType>& nhwc)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++)
|
||||
{
|
||||
std::cout << "[";
|
||||
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++)
|
||||
{
|
||||
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
std::cout << "]";
|
||||
}
|
||||
|
||||
template <int NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool profile_convnd_bwd_weight_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t split_k)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(C)};
|
||||
input_dims.insert(
|
||||
std::end(input_dims), std::begin(input_spatial_lengths), std::end(input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(K), static_cast<std::size_t>(C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(N), static_cast<std::size_t>(K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input(get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
|
||||
Tensor<WeiDataType> weights_host_result(
|
||||
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
|
||||
Tensor<WeiDataType> weights_device_result(
|
||||
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
|
||||
Tensor<OutDataType> output(
|
||||
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights_host_result.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
output.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
output.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto RunReference = [&](auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights_host_result,
|
||||
output,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NDimSpatial>();
|
||||
RunReference(ref_conv);
|
||||
}
|
||||
|
||||
// add device Conv instances
|
||||
std::vector<DeviceConvndBwdWeightNoOpPtr> conv_ptrs;
|
||||
get_device_conv_bwd_weight_op_ptr(
|
||||
InDataType{}, WeiDataType{}, OutDataType{}, conv_ptrs, NDimSpatial);
|
||||
|
||||
if(conv_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device Conv instance found");
|
||||
}
|
||||
|
||||
std::string best_conv_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device Conv instances
|
||||
bool success = true;
|
||||
for(auto& conv_ptr : conv_ptrs)
|
||||
{
|
||||
// using atomic, so need to reset input, setzero is done in invoker
|
||||
// if(split_k > 1)
|
||||
//{
|
||||
// wei_device_buf.SetZero();
|
||||
//}
|
||||
|
||||
auto argument_ptr = conv_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k);
|
||||
|
||||
if(!conv_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
|
||||
std::string conv_name = conv_ptr->GetTypeString();
|
||||
float ave_time = 0;
|
||||
|
||||
if(std::is_same<InDataType, ck::bhalf_t>::value && split_k > 1)
|
||||
{
|
||||
// alloc work space
|
||||
size_t bwd_weight_workspace_size = conv_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
if(bwd_weight_workspace_size <= 0)
|
||||
{
|
||||
printf("wrong work space size\n");
|
||||
exit(1);
|
||||
}
|
||||
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
||||
wei_work_space_device_buf.SetZero();
|
||||
conv_ptr->SetWorkSpacePointer(argument_ptr.get(),
|
||||
wei_work_space_device_buf.GetDeviceBuffer());
|
||||
ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
}
|
||||
|
||||
std::size_t flop =
|
||||
ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_conv_name = conv_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
wei_device_buf.FromDevice(weights_device_result.mData.data());
|
||||
|
||||
success = ck::utils::check_err(weights_host_result, weights_device_result);
|
||||
|
||||
if(success == false)
|
||||
{
|
||||
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(output);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(weights_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out : ";
|
||||
show_data_nhwc_layout(input);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei_device: ";
|
||||
show_data_nhwc_layout(weights_device_result);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
|
||||
return success;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user