mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
update implicit GEMM forward v4r4 to use gridwise gemm (#9)
* updated fwd v4r4 to use gridwise gemm * updated gridwise gemm api calls in bwd-data v1r1 and v2r1
This commit is contained in:
@@ -11,8 +11,8 @@ template <typename T,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -62,24 +62,26 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = C * Y * X;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
@@ -93,8 +95,8 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
@@ -107,13 +109,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -11,8 +11,8 @@ template <typename T,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -68,54 +68,26 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
#elif 0
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
|
||||
// simplicity
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -126,12 +98,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
|
||||
constexpr index_t GemmK = K * Ydot * Xdot;
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * Htilda * Wtilda;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
@@ -145,8 +116,8 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
@@ -159,13 +130,15 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
@@ -11,8 +11,8 @@ template <class T,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
@@ -21,8 +21,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
constexpr auto in_nchw_desc =
|
||||
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
|
||||
constexpr auto wei_kcyx_desc =
|
||||
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
|
||||
constexpr auto out_nkhw_desc =
|
||||
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t K = out_nkhw_desc.GetLength(I1);
|
||||
@@ -51,198 +54,174 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<4, 1>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 1;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<1, 4>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 16
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<2, 4>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<2, 2>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 2;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 2;
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
constexpr index_t GemmM = K;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
OutThreadCopyDataPerAccess_B>{};
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 16;
|
||||
|
||||
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
|
||||
constexpr index_t N = 128;
|
||||
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
|
||||
Reference in New Issue
Block a user