mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
MIOpen integration (#13)
* update for MIOpen integration: cosmetic refactor
[ROCm/composable_kernel commit: 1a66e35b6f]
This commit is contained in:
@@ -49,20 +49,20 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// BlockSize = 256, each thread holds 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -79,6 +79,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread holds 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = C * Y * X;
|
||||
@@ -104,13 +134,13 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
|
||||
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -96,13 +96,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -127,13 +127,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -152,33 +152,33 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
constexpr index_t HTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
constexpr index_t HTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
|
||||
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
constexpr index_t GemmM = C * YTilda * XTilda;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
@@ -200,13 +200,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
|
||||
@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -91,33 +91,33 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
constexpr index_t HTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
constexpr index_t HTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
|
||||
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
@@ -139,13 +139,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
|
||||
@@ -69,13 +69,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -99,13 +99,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
@@ -124,33 +124,33 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
constexpr index_t HTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
constexpr index_t HTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
|
||||
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
@@ -159,7 +159,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
@@ -174,13 +174,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
@@ -196,21 +196,29 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
|
||||
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
|
||||
constexpr index_t gemm_id = decltype(gemm_id_){};
|
||||
|
||||
launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
|
||||
gemm_id,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
|
||||
constexpr index_t gemm_k = gemm_sizes.At(2);
|
||||
constexpr bool is_gemm_not_empty = gemm_k > 0;
|
||||
|
||||
// only compile and run if GEMM is not empty
|
||||
static_if<is_gemm_not_empty>{}([&](auto fwd) {
|
||||
launch_kernel(
|
||||
run_gridwise_convolution_backward_data_v4r1<GridwiseConvBwdData,
|
||||
fwd(gemm_id),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
});
|
||||
});
|
||||
|
||||
timer.End();
|
||||
|
||||
Reference in New Issue
Block a user