mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
update implicit GEMM forward v4r4 to use gridwise gemm (#9)
* updated fwd v4r4 to use gridwise gemm * updated gridwise gemm api calls in bwd-data v1r1 and v2r1
This commit is contained in:
@@ -8,6 +8,9 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = C * Y * X
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = K
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
@@ -17,11 +20,11 @@ template <index_t GridSize,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t EPerBlock,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
@@ -31,13 +34,15 @@ template <index_t GridSize,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename WeiBlockCopySubLengths_K_E,
|
||||
typename WeiBlockCopyClusterLengths_K_E,
|
||||
index_t WeiBlockCopyDataPerAccess_E,
|
||||
typename OutBlockCopySubLengths_K_B,
|
||||
typename OutBlockCopyClusterLengths_K_B,
|
||||
index_t OutBlockCopyDataPerAccess_B,
|
||||
index_t InThreadCopyDataPerAccess_B>
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmN,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
@@ -49,8 +54,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
@@ -73,14 +76,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_howo_global_desc =
|
||||
@@ -99,8 +101,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
@@ -121,33 +124,43 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
|
||||
// GEMM: atomic add
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_k_e_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
InMemoryDataOperation::atomic_add,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
WeiBlockCopySubLengths_K_E,
|
||||
WeiBlockCopyClusterLengths_K_E,
|
||||
WeiBlockCopyDataPerAccess_E,
|
||||
OutBlockCopySubLengths_K_B,
|
||||
OutBlockCopyClusterLengths_K_B,
|
||||
OutBlockCopyDataPerAccess_B,
|
||||
InThreadCopyDataPerAccess_B>{};
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_k_e_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
InMemoryDataOperation::atomic_add,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmN,
|
||||
GemmABlockCopyDstDataPerWrite_GemmN,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmK = K * Ydot * Xdot;
|
||||
// GemmM = C * Ytilda * Xtilda;
|
||||
// GemmN = N * Htilda * Wtilda;
|
||||
// GemmK = K * Ydot * Xdot;
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
@@ -34,14 +34,15 @@ template <index_t GridSize,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M
|
||||
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
|
||||
index_t GemmABlockCopyDataPerAccess, // Gemm-M
|
||||
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N
|
||||
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
|
||||
index_t GemmBBlockCopyDataPerAccess, // Gemm-N
|
||||
index_t GemmCThreadCopyDataPerAccess // Gemm-N
|
||||
>
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
@@ -71,10 +72,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
@@ -172,33 +175,43 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
|
||||
// GEMM
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmK,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
|
||||
InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
|
||||
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,408 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
// B = merge(N, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_B,
|
||||
typename InBlockCopyClusterLengths_E_B,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_B,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_ho_wo_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global mem
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// LDS mem
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_b_block_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
|
||||
|
||||
// input blockwise copy
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
// global mem
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
|
||||
// LDS
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// weight blockwise copy
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlocl, BPerBlock] is in LDS
|
||||
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(
|
||||
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
|
||||
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat =
|
||||
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO:: more elegent way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_b_block_mtx_desc),
|
||||
decltype(c_k0k1_b0b1_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
|
||||
|
||||
if(has_two_iteration_left) // if has 2 iteration left
|
||||
{
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
// src descriptor
|
||||
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
|
||||
|
||||
// dst descriptor
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
constexpr index_t K0 = K / K1;
|
||||
constexpr index_t B0 = B / B1;
|
||||
|
||||
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
|
||||
out_k_b_global_desc,
|
||||
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// output threadwise copy
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_k0_k1_b0_b1_thread_desc),
|
||||
decltype(out_k0_k1_b0_b1_global_desc),
|
||||
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 4, 1>::type,
|
||||
3,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global>({0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1})
|
||||
.Run(p_out_thread, p_out_global);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -9,6 +9,12 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the vector size to be different on src and dst.
|
||||
// The dimension of vector access can be different for src and dst.
|
||||
// The dimension access order can be different for src and dst.
|
||||
// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
|
||||
// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
|
||||
template <index_t BlockSize,
|
||||
typename BlockSrcDesc,
|
||||
typename BlockDstDesc,
|
||||
@@ -18,10 +24,10 @@ template <index_t BlockSize,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess,
|
||||
index_t SrcVectoReadDim,
|
||||
index_t DstVectorWriteDim,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::generic,
|
||||
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::generic,
|
||||
@@ -146,8 +152,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
ThreadBufferDesc,
|
||||
ThreadSliceLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
SrcVectoReadDim,
|
||||
SrcDataPerRead,
|
||||
1,
|
||||
SrcAddressSpace,
|
||||
ThreadBufferAddressSpace,
|
||||
@@ -157,9 +163,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
BlockDstDesc,
|
||||
ThreadSliceLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorWriteDim,
|
||||
1,
|
||||
DstDataPerAccess,
|
||||
DstDataPerWrite,
|
||||
ThreadBufferAddressSpace,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>;
|
||||
|
||||
@@ -31,14 +31,24 @@ template <index_t GridSize,
|
||||
index_t KPerThreadLoop,
|
||||
index_t ThreadGemmDataPerReadM,
|
||||
index_t ThreadGemmDataPerReadN,
|
||||
typename ABlockCopySubLengths_K_M,
|
||||
typename ABlockCopyClusterLengths_K_M,
|
||||
index_t ABlockCopyDataPerAccess_M,
|
||||
typename BBlockCopySubLengths_K_N,
|
||||
typename BBlockCopyClusterLengths_K_N,
|
||||
index_t BBlockCopyDataPerAccess_N,
|
||||
index_t CThreadCopyDataPerAccess_N>
|
||||
struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
typename ABlockCopyThreadSliceLengths_K_M,
|
||||
typename ABlockCopyThreadClusterLengths_K_M,
|
||||
typename ABlockCopyThreadClusterArrangeOrder,
|
||||
typename ABlockCopySrcAccessOrder,
|
||||
index_t ABlockCopySrcVectorReadDim,
|
||||
index_t ABlockCopySrcDataPerRead,
|
||||
index_t ABlockCopyDstDataPerWrite_M,
|
||||
typename BBlockCopyThreadSliceLengths_K_N,
|
||||
typename BBlockCopyThreadClusterLengths_K_N,
|
||||
typename BBlockCopyThreadClusterArrangeOrder,
|
||||
typename BBlockCopySrcAccessOrder,
|
||||
index_t BBlockCopySrcVectorReadDim,
|
||||
index_t BBlockCopySrcDataPerRead,
|
||||
index_t BBlockCopyDstDataPerWrite_N,
|
||||
typename CThreadCopySrcDstAccessOrder,
|
||||
index_t CThreadCopySrcDstVectorReadWriteDim,
|
||||
index_t CThreadCopyDstDataPerWrite>
|
||||
struct GridwiseGemmTransposedANormalBNormalC_v1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
@@ -55,8 +65,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
|
||||
|
||||
// lds max alignment
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN);
|
||||
|
||||
@@ -86,15 +96,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
decltype(a_k_m_global_desc),
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(a_k_m_block_desc.GetLengths()),
|
||||
ABlockCopySubLengths_K_M,
|
||||
ABlockCopyClusterLengths_K_M,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
ABlockCopyThreadSliceLengths_K_M,
|
||||
ABlockCopyThreadClusterLengths_K_M,
|
||||
ABlockCopyThreadClusterArrangeOrder,
|
||||
ABlockCopySrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockCopySrcVectorReadDim,
|
||||
1,
|
||||
1,
|
||||
ABlockCopyDataPerAccess_M,
|
||||
ABlockCopyDataPerAccess_M,
|
||||
ABlockCopySrcDataPerRead,
|
||||
ABlockCopyDstDataPerWrite_M,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
@@ -112,15 +122,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
decltype(b_k_n_global_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
decltype(b_k_n_block_desc.GetLengths()),
|
||||
BBlockCopySubLengths_K_N,
|
||||
BBlockCopyClusterLengths_K_N,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
BBlockCopyThreadSliceLengths_K_N,
|
||||
BBlockCopyThreadClusterLengths_K_N,
|
||||
BBlockCopyThreadClusterArrangeOrder,
|
||||
BBlockCopySrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
BBlockCopySrcVectorReadDim,
|
||||
1,
|
||||
1,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
BBlockCopySrcDataPerRead,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
@@ -304,10 +314,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
decltype(c_m0_m1_n0_n1_global_desc),
|
||||
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CThreadCopyDataPerAccess_N,
|
||||
CThreadCopyDataPerAccess_N,
|
||||
CThreadCopySrcDstAccessOrder,
|
||||
CThreadCopySrcDstVectorReadWriteDim,
|
||||
1,
|
||||
CThreadCopyDstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
CGlobalMemoryDataOperation>(
|
||||
|
||||
@@ -8,20 +8,19 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This version use multi-index transformation
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the vector size to be different on src and dst.
|
||||
// The dimensions of vector access should be the same on src and dst.
|
||||
// The dimension access order should be the same on src and dst.
|
||||
// It is designed for cases, where one of src and dst is register, and
|
||||
// the other is device memory or LDS
|
||||
// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
|
||||
// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess,
|
||||
typename SrcDstDimAccessOrder,
|
||||
index_t SrcDstVectorReadWriteDim,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::generic,
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
|
||||
@@ -39,16 +38,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
nDim == SrcDstDimAccessOrder::Size(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
|
||||
static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");
|
||||
|
||||
static_assert(
|
||||
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
|
||||
math::lcm(SrcDataPerRead, DstDataPerWrite) ==
|
||||
0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// TODO:: sanity-check if vectorized memory access is allowed on src and dst
|
||||
// TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
|
||||
@@ -67,22 +67,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in paddin area.
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
|
||||
ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
|
||||
auto long_vector_access_id) {
|
||||
|
||||
// data id w.r.t slicing-window
|
||||
@@ -109,13 +107,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
|
||||
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
|
||||
|
||||
// Check src vector's padding situation, only check the first data in this src
|
||||
// Check src data's valid mapping situation, only check the first data in this src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the same padding situation
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
@@ -141,13 +139,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
|
||||
const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);
|
||||
|
||||
// Check dst vector's padding situation, only check the first data in this dst
|
||||
// Check dst data's valid mapping situation, only check the first data in this dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the same padding situation
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
@@ -165,20 +163,20 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
return Sequence<(Mask ? Lengths : 1)...>{};
|
||||
}
|
||||
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in paddin area.
|
||||
// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
|
||||
// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
|
||||
// This version is optimized for address calculation of src tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
{
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
@@ -187,10 +185,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
|
||||
constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();
|
||||
|
||||
static_assert(src_linear_dim_mask.At(VectorAccessDim) ||
|
||||
long_vector_size == SrcDataPerAccess,
|
||||
"Warning! VectorAccessDim is not SrcDesc's linear dimension, performance "
|
||||
"would drop");
|
||||
static_assert(
|
||||
src_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == SrcDataPerRead,
|
||||
"Warning! SrcDstVectorReadWriteDim is not SrcDesc's linear dimension, performance "
|
||||
"would drop");
|
||||
|
||||
// separate steps into linear and non-linear components, accoording to src tensor
|
||||
constexpr auto linear_long_vector_access_lengths =
|
||||
@@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
p_src_long_vector[i] = 0;
|
||||
}
|
||||
|
||||
// Loop over VectorAccessDim, and load data from src to the
|
||||
// Loop over SrcDstVectorReadWriteDim, and load data from src to the
|
||||
// long-vector buffer.
|
||||
// If VectorAccessDim is src's linear dimension, then src's
|
||||
// If SrcDstVectorReadWriteDim is src's linear dimension, then src's
|
||||
// offset-diff due to this looping is known at compile-time. If
|
||||
// VectorAccessDim is src's nonlinear dimension, then src's
|
||||
// SrcDstVectorReadWriteDim is src's nonlinear dimension, then src's
|
||||
// offset-diff due to this looping is only known at run-time. For best
|
||||
// performance, VectorAccessDim, should be src's linear dimension
|
||||
// performance, SrcDstVectorReadWriteDim, should be src's linear dimension
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
@@ -258,13 +256,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
|
||||
#endif
|
||||
|
||||
// Check src vector's padding situation, only check the first data in
|
||||
// this src vector. It's user's responsiblity to make sure all data in
|
||||
// the src vector has the same padding situation
|
||||
// Check src data's valid mapping situation, only check the first data in this
|
||||
// src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
@@ -296,13 +295,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
|
||||
linear_dim_data_steps + scalar_id);
|
||||
|
||||
// Check dst vector's padding situation, only check the first data in
|
||||
// this dst vector. It's user's responsiblity to make sure all data in
|
||||
// the dst vector has the same padding situation
|
||||
// Check dst data's valid mapping situation, only check the first data in this
|
||||
// dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
@@ -313,20 +313,18 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
});
|
||||
}
|
||||
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in paddin area.
|
||||
// This version is optimized for address calculation of dst tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
{
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
@@ -335,10 +333,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
|
||||
constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
|
||||
|
||||
static_assert(dst_linear_dim_mask.At(VectorAccessDim) ||
|
||||
long_vector_size == DstDataPerAccess,
|
||||
"Warning! VectorAccessDim is not DstDesc's linear dimension, performance "
|
||||
"would drop");
|
||||
static_assert(
|
||||
dst_linear_dim_mask.At(SrcDstVectorReadWriteDim) || long_vector_size == DstDataPerWrite,
|
||||
"Warning! SrcDstVectorReadWriteDim is not DstDesc's linear dimension, performance "
|
||||
"would drop");
|
||||
|
||||
// separate steps into linear and non-linear components, accoording to dst tensor
|
||||
constexpr auto linear_long_vector_access_lengths =
|
||||
@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
p_src_long_vector[i] = 0;
|
||||
}
|
||||
|
||||
// Loop over VectorAccessDim, and load data from src to the
|
||||
// Loop over SrcDstVectorReadWriteDim, and load data from src to the
|
||||
// long-vector buffer.
|
||||
// If VectorAccessDim is dst's linear dimension, then dst's
|
||||
// If SrcDstVectorReadWriteDim is dst's linear dimension, then dst's
|
||||
// offset-diff due to this looping is known at compile-time. If
|
||||
// VectorAccessDim is dst's nonlinear dimension, then dst's
|
||||
// SrcDstVectorReadWriteDim is dst's nonlinear dimension, then dst's
|
||||
// offset-diff due to this looping is only known at run-time. For best
|
||||
// performance, VectorAccessDim, should be dst's linear dimension
|
||||
// performance, SrcDstVectorReadWriteDim, should be dst's linear dimension
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
@@ -397,13 +395,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
const auto src_coord = mSrcSliceOrigin + (nonlinear_dim_data_steps +
|
||||
linear_dim_data_steps + scalar_id);
|
||||
|
||||
// Check src vector's padding situation, only check the first data in
|
||||
// this src vector. It's user's responsiblity to make sure all data in
|
||||
// the src vector has the same padding situation
|
||||
// Check src data's valid mapping situation, only check the first data in this
|
||||
// src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
@@ -441,13 +440,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
|
||||
#endif
|
||||
|
||||
// Check dst vector's padding situation, only check the first data in
|
||||
// this dst vector. It's user's responsiblity to make sure all data in
|
||||
// the dst vector has the same padding situation
|
||||
// Check dst data's valid mapping situation, only check the first data in this
|
||||
// dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
|
||||
@@ -11,8 +11,8 @@ template <typename T,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -62,24 +62,26 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = C * Y * X;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
@@ -93,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -11,8 +11,8 @@ template <typename T,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -68,54 +68,26 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
#elif 0
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
|
||||
// simplicity
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -126,12 +98,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
|
||||
constexpr index_t GemmK = K * Ydot * Xdot;
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * Htilda * Wtilda;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
@@ -145,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
@@ -159,13 +130,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
@@ -11,8 +11,8 @@ template <class T,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
constexpr auto in_nchw_desc =
|
||||
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
|
||||
constexpr auto wei_kcyx_desc =
|
||||
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
|
||||
constexpr auto out_nkhw_desc =
|
||||
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t K = out_nkhw_desc.GetLength(I1);
|
||||
@@ -51,198 +54,174 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<4, 1>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 1;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<1, 4>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 16
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<2, 4>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<2, 2>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 2;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 2;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
constexpr index_t GemmM = K;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
OutThreadCopyDataPerAccess_B>{};
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 16;
|
||||
|
||||
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
|
||||
constexpr index_t N = 128;
|
||||
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
|
||||
Reference in New Issue
Block a user