mirror of https://github.com/ROCm/composable_kernel.git
MIopen integration (#13)
* update for miopen integration: cosmetic refactor
@@ -114,10 +114,10 @@ struct GridwiseCol2Im_eb_nchw
1,
BlockCopyDataPerAccess_B,
BlockCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::atomic_add>(
AddressSpace::Vgpr,
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::AtomicAdd>(
{e_block_data_on_global, b_block_data_on_global},
{e_block_data_on_global, b_block_data_on_global});
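The hunks in this commit are dominated by a cosmetic enumerator rename: lowercase values such as AddressSpace::vgpr become PascalCase (AddressSpace::Vgpr), and InMemoryDataOperation::none / atomic_add become Set / AtomicAdd. For orientation, a minimal sketch of the two renamed enums, assuming plain scoped enums (the library's actual definitions may carry more members):

// Hypothetical post-rename shape of the enums touched throughout this commit.
enum class AddressSpace
{
    Generic, // compiler-chosen address space
    Global,  // off-chip device memory
    Lds,     // on-chip shared memory (LDS on AMD GPUs)
    Vgpr     // per-thread register file
};

enum class InMemoryDataOperation
{
    Set,      // plain store that overwrites the destination
    AtomicAdd // read-modify-write accumulation into the destination
};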
@@ -25,15 +25,15 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmN,
@@ -75,25 +75,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
//\todo static_assert for global vector load/store
// statc_assert();

// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

// output tensor
constexpr auto out_k_b_global_desc =
constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// weight tensor
constexpr auto wei_k_e_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
@@ -116,38 +111,42 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// GEMM
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
// \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
// atomic, find out all of them
constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
(ConvStrideW >= ConvDilationW * (X - 1) + 1);

constexpr auto in_memory_op =
not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;

constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
in_memory_op,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
ThreadGemmAThreadCopySrcDataPerRead_GemmM,
ThreadGemmAThreadCopySrcDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
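The hunk above also replaces the crude (Y <= ConvStrideH && X <= ConvStrideW) test with not_need_atomic: a dimension can use a plain Set store only when one stride step moves past the entire dilated filter footprint, so no two filter taps write the same input pixel. A standalone sketch of the same test, assuming index_t behaves like a plain integer (the helper name is hypothetical, not part of the library):

// One dimension needs AtomicAdd only if the stride is smaller than the
// dilated filter footprint ConvDilation * (FilterLen - 1) + 1, i.e. if
// consecutive output positions can scatter into overlapping input pixels.
constexpr bool needs_atomic_add(int conv_stride, int conv_dilation, int filter_len)
{
    return conv_stride < conv_dilation * (filter_len - 1) + 1;
}

static_assert(!needs_atomic_add(2, 1, 2), "stride 2 clears a 2-tap filter, Set suffices");
static_assert(needs_atomic_add(1, 1, 3), "stride 1 under a 3-tap filter overlaps, AtomicAdd");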
@@ -147,10 +147,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
2,
OutBlockCopySrcDataPerRead_B,
OutBlockCopyDstDataPerWrite_N0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, b_block_data_on_global, 0}, {0, 0, 0});

// weight tensor
@@ -187,10 +187,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
2,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_C0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, e_block_data_on_global, 0}, {0, 0, 0});

// GEMM definition
@@ -356,10 +356,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
? InMemoryDataOperation::Set
: InMemoryDataOperation::AtomicAdd;
#else
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
#endif

constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
@@ -432,8 +432,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
4,
1,
InThreadCopyDstDataPerWrite_B,
AddressSpace::vgpr,
AddressSpace::global,
AddressSpace::Vgpr,
AddressSpace::Global,
in_memory_op>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
@@ -8,9 +8,9 @@

namespace ck {

// GemmM = C * Ytilda * Xtilda;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot;
// GemmM = C * YTilda * XTilda;
// GemmN = N * HTildaSlice * WTildaSlice;
// GemmK = K * YDot * XDot;
template <index_t GridSize,
index_t BlockSize,
typename Float,
@@ -25,13 +25,13 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -81,32 +81,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
"be violated");
#endif

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda =
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
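To make the renamed quantities concrete, here is a hypothetical worked example for the H dimension with ConvStrideH = 2, ConvDilationH = 1, Y = 3, Ho = 4, written in plain C++17 (assuming the library's math::gcd and math::integer_divide_ceil match std::gcd and round-up division):

#include <numeric> // std::gcd

constexpr int conv_stride_h = 2, conv_dilation_h = 1, y = 3, ho = 4;
constexpr int gcd_sd_h = std::gcd(conv_stride_h, conv_dilation_h);      // GcdStrideDilationH = 1
constexpr int y_tilda  = conv_stride_h / gcd_sd_h;                      // YTilda = 2 GEMM phases along H
constexpr int y_dot    = (y + y_tilda - 1) / y_tilda;                   // YDot = ceil(3 / 2) = 2
constexpr int h_tilda  = ho + (conv_dilation_h * (y - 1) + conv_stride_h - 1)
                                  / conv_stride_h;                      // HTilda = 4 + 1 = 5
static_assert(y_tilda == 2 && y_dot == 2 && h_tilda == 5, "worked example");

The HTildaLeft/HTildaRight bounds then trim HTilda down to the rows that actually touch the non-padded input area, which is why the renamed HTildaSlice (formerly HtildaTrim) is a slice length.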
@@ -114,17 +114,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
make_tuple(Merge<Sequence<K, YDot, XDot>>{}, Merge<Sequence<C, YTilda, XTilda>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -134,33 +134,33 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

constexpr auto out_gemmk_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDot, XDot>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -188,35 +188,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

constexpr auto in_gemmm_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, YTilda, XTilda>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -229,17 +229,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -8,10 +8,10 @@

namespace ck {

// Ytilda*Xtilda number of GEMMs
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <index_t GridSize,
index_t BlockSize,
typename Float,
@@ -26,13 +26,13 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -110,32 +110,32 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
"be violated");
#endif

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda =
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

constexpr bool wei_skip_all_out_of_bound_check = true;
@@ -145,12 +145,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -167,26 +167,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
@@ -216,26 +216,26 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<HTildaLeft, WTildaLeft>,
Sequence<HTildaRight, WTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
@@ -246,54 +246,49 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw

__shared__ Float p_shared_block[shared_block_size];

#if 1 // debug
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<0, 1, 1>{}([&](auto ytilda_) {
static_for<0, 1, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
static_for<0, YTilda, 1>{}([&](auto iYTilda_) {
static_for<0, XTilda, 1>{}([&](auto iXTilda_) {
constexpr index_t iYTilda = decltype(iYTilda_){};
constexpr index_t iXTilda = decltype(iXTilda_){};

constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
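The renamed YDotSlice/XDotSlice make the intent of the old YdotNonZero/XdotNonZero explicit: each of the YTilda * XTilda GEMMs covers YDot (resp. XDot) filter taps, except possibly the last, which gets the remainder when the filter length does not divide evenly. A minimal sketch of the rule, using a hypothetical helper with the same formula as the diff:

constexpr int slice_length(int i_tilda, int dot, int filter_len)
{
    // full YDot/XDot slice if it fits, otherwise the leftover taps
    return (i_tilda + 1) * dot <= filter_len ? dot : filter_len % dot;
}

static_assert(slice_length(0, 2, 3) == 2, "first GEMM takes a full slice");
static_assert(slice_length(1, 2, 3) == 1, "last GEMM takes the remainder");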
@@ -306,23 +301,23 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Sequence<2, 4>{}));

constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
@@ -335,9 +330,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
Sequence<2, 4>{}));

constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -349,17 +344,17 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -8,9 +8,10 @@

namespace ck {

// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * Htilda * Wtilda;
// GemmK = K * YdotNonZero * XdotNonZero
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <index_t GridSize,
index_t BlockSize,
typename Float,
@@ -25,15 +26,15 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmM,
@@ -53,13 +54,90 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

return Ytilda * Xtilda;
return YTilda * XTilda;
}

__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
{
constexpr index_t N = InGlobalDesc::GetLengths()[0];
constexpr index_t C = InGlobalDesc::GetLengths()[1];
constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
constexpr index_t Wi = InGlobalDesc::GetLengths()[3];

constexpr index_t K = OutGlobalDesc::GetLengths()[1];
constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];

constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
constexpr index_t X = WeiGlobalDesc::GetLengths()[3];

constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];

constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

// GemmM and GemmN
constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

// GemmK is different for each GEMM
index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

index_t GemmK = K * YDotSlice * XDotSlice;

return Array<index_t, 3>{GemmM, GemmN, GemmK};
}

__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
{
constexpr index_t ConvStrideW = ConvStrides{}[1];

constexpr index_t ConvDilationW = ConvDilations{}[1];

constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

index_t iYTilda = gemm_id / XTilda;
index_t iXTilda = gemm_id % XTilda;

return GetGemmSizeImpl(iYTilda, iXTilda);
}

template <index_t iYTilda, index_t iXTilda>
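The new GetGemmSize entry point decomposes a flat gemm_id into the (iYTilda, iXTilda) pair row-major over the XTilda-wide grid of GEMMs. A self-contained sketch of that mapping (the names are illustrative, not the library's):

struct GemmCoord
{
    int i_y_tilda;
    int i_x_tilda;
};

constexpr GemmCoord decompose_gemm_id(int gemm_id, int x_tilda)
{
    return {gemm_id / x_tilda, gemm_id % x_tilda};
}

// e.g. with XTilda == 2, gemm_id 3 maps to (iYTilda, iXTilda) == (1, 1)
static_assert(decompose_gemm_id(3, 2).i_y_tilda == 1 &&
              decompose_gemm_id(3, 2).i_x_tilda == 1, "row-major decomposition");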
@@ -89,44 +167,39 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif
//\todo static_assert for global vector load/store
// statc_assert();

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda =
constexpr index_t HTilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
constexpr index_t WTilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
constexpr index_t iHTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t iWTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t iHTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t iWTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

constexpr bool wei_skip_all_out_of_bound_check = true;
// weight out-of-bound check can be skipped
constexpr bool wei_skip_out_of_bound_check = true;

// weight tensor
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
@@ -134,20 +207,22 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Sequence<YDot, YTilda>,
Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
wei_skip_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
Sequence<XDot, XTilda>,
Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
wei_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

#if 1 // debug
constexpr bool out_skip_all_out_of_bound_check = false;
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false;
#else
constexpr bool out_skip_all_out_of_bound_check = true;
//\todo sometimes output tensor out-of-bound check can be skipped, find out all such
// situations
constexpr bool out_skip_out_of_bound_check = true;
#endif

// output tensor
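The #if 1 // debug toggle is replaced by a named experimental macro, so the conservative out-of-bound check stays the default and skipping it becomes an explicit opt-in. A sketch of the pattern, under the assumption that the macro defaults to 0 somewhere in the library's config headers:

// Assumed default; the real definition likely lives in a config header.
#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
#endif

#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool out_skip_out_of_bound_check = false; // safe default
#else
constexpr bool out_skip_out_of_bound_check = true;  // opt-in fast path
#endif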
@@ -156,35 +231,36 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Sequence<YDot, HTilda>,
Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
out_skip_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
Sequence<XDot, WTilda>,
Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
out_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

#if 1 // debug
constexpr bool in_skip_all_out_of_bound_check = false;
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
constexpr bool in_skip_out_of_bound_check = false;
#else
constexpr bool in_skip_all_out_of_bound_check = true;
//\todo sometimes input out-of-bound check can be skipped, find out all such situations
constexpr bool in_skip_out_of_bound_check = true;
#endif

// input tensor
@@ -193,7 +269,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(
PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -205,100 +281,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip,
Sequence<Ytilda, Htilda>,
Sequence<YTilda, HTilda>,
Sequence<ConvDilationH, ConvStrideH, 0>,
in_skip_all_out_of_bound_check>{},
in_skip_out_of_bound_check>{},
Embed<Wip,
Sequence<Xtilda, Wtilda>,
Sequence<XTilda, WTilda>,
Sequence<ConvDilationW, ConvStrideW, 0>,
in_skip_all_out_of_bound_check>{}),
in_skip_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<Ytilda>{},
PassThrough<Xtilda>{},
Slice<Sequence<Htilda, Wtilda>,
Sequence<HtildaLeft, WtildaLeft>,
Sequence<HtildaRight, WtildaRight>>{}),
PassThrough<YTilda>{},
PassThrough<XTilda>{},
Slice<Sequence<HTilda, WTilda>,
Sequence<iHTildaLeft, iWTildaLeft>,
Sequence<iHTildaRight, iWTildaRight>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

// GEMM
constexpr index_t ytilda = iYTilda;
constexpr index_t xtilda = iXTilda;

constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

// A matrix
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
PassThrough<K>{},
PassThrough<C>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, Merge<Sequence<C, 1, 1>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// B matrix
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
transform_tensor_descriptor(
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ydot, Xdot>,
Sequence<0, 0>,
Sequence<YdotNonZero, XdotNonZero>>{}),
out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
make_tuple(
PassThrough<N>{},
PassThrough<K>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// C matrix
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HtildaTrim>{},
PassThrough<WtildaTrim>{},
Slice<Sequence<Ytilda, Xtilda>,
Sequence<ytilda, xtilda>,
Sequence<ytilda + 1, xtilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
transform_tensor_descriptor(
in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
PassThrough<HTildaSlice>{},
PassThrough<WTildaSlice>{},
Slice<Sequence<YTilda, XTilda>,
Sequence<iYTilda, iXTilda>,
Sequence<iYTilda + 1, iXTilda + 1>>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -310,19 +382,19 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
ThreadGemmAThreadCopySrcDataPerRead_GemmM,
ThreadGemmAThreadCopySrcDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<0, 1>,
@@ -355,16 +427,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t iYTilda = GemmId / Xtilda;
constexpr index_t iXTilda = GemmId % Xtilda;
constexpr index_t iYTilda = GemmId / XTilda;
constexpr index_t iXTilda = GemmId % XTilda;

static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda");

RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
}
@@ -229,10 +229,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

// weight tensor
@@ -269,10 +269,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, k_block_data_on_global}, {0, 0});

// GEMM definition
@@ -344,6 +344,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}

constexpr auto in_block_slice_copy_steps = Sequence<EPerBlock, 0, 0, 0>{};
constexpr auto wei_block_slice_copy_steps = Sequence<EPerBlock, 0>{};

// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
@@ -366,8 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);

__syncthreads();
@@ -393,8 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -482,14 +485,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
3,
1,
1,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::none>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::Set>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global);
}
}

@@ -94,9 +94,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
constexpr auto True = integral_constant<bool, true>{};

constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};
integral_constant<AddressSpace, AddressSpace::Global>{};

static_assert(ConvDirection == ConvolutionDirection::Forward ||
ConvDirection == ConvolutionDirection::BackwardWeight,
@@ -141,13 +141,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep
constexpr index_t E = C * Y * X;

// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
static_assert(
(Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");

// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");

constexpr index_t KBlockWork = K / KPerBlock;
@@ -357,37 +358,49 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_dep

// LDS double buffer: tail
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);

blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
if(has_two_iteration_left) // if has 2 iteration left
{
// even iteration
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

__syncthreads();
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
__syncthreads();

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS doubel buffer: load next data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

// odd iteration
__syncthreads();
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
// odd iteration
__syncthreads();

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}

// copy output: register to global memory

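The tail rewrite above replaces an unconditional two-iteration drain with a branch on has_two_iteration_left, which pairs with the relaxed divisibility check earlier in this file (E % EPerBlock == 0 instead of E % (2 * EPerBlock) == 0). A self-contained sketch of the same ping-pong control flow, with plain arrays standing in for LDS and prints standing in for the GEMM (illustrative, not the kernel's code):

#include <cstdio>

constexpr int EPerBlock = 4;

void load(int* buf, int chunk) { buf[0] = chunk; }   // stands in for the load/store thread-buffer pair
void compute(const int* buf) { std::printf("gemm on chunk %d\n", buf[0]); } // stands in for blockwise_gemm.Run

int main()
{
    constexpr int E = 5 * EPerBlock; // odd number of chunks -> one-iteration tail
    int buf[2][1];
    int chunk = 0;
    load(buf[0], chunk); // preload

    // main body: consume two chunks per pass, alternating buffers
    for(int e = 0; e + 2 * EPerBlock < E; e += 2 * EPerBlock)
    {
        load(buf[1], chunk + 1);
        compute(buf[0]); // even iteration
        load(buf[0], chunk + 2);
        compute(buf[1]); // odd iteration
        chunk += 2;
    }

    // tail: two chunks remain iff E % (2 * EPerBlock) == 0, else one
    if(E % (2 * EPerBlock) == 0)
    {
        load(buf[1], chunk + 1);
        compute(buf[0]);
        compute(buf[1]);
    }
    else
    {
        compute(buf[0]);
    }
    return 0;
}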
@@ -25,15 +25,15 @@ template <index_t GridSize,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmM,
index_t ThreadGemmAThreadCopySrcDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
@@ -130,19 +130,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
decltype(wei_e_k_global_desc),
decltype(in_e_b_global_desc),
decltype(out_k_b_global_desc),
InMemoryDataOperation::none,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
ThreadGemmAThreadCopySrcDataPerRead_GemmM,
ThreadGemmAThreadCopySrcDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,

@@ -251,9 +251,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep

// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
blockwise_in_copy.template Run<Float, AddressSpace::Global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
blockwise_wei_copy.template Run<Float, AddressSpace::Global>(p_wei_global,
p_wei_block_double);
}

@@ -285,9 +285,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
__syncthreads();

// LDS doubel buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);

// LDS double buffer: GEMM on current data
@@ -311,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
__syncthreads();

// LDS doubel buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
p_wei_global, p_wei_thread_buffer);

// LDS double buffer: GEMM on current data
@@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_dep
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
.template Run<Float, AddressSpace::Generic, AddressSpace::Global>(p_out_thread,
p_out_global);

threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);

@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto

template <typename... Ts>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");

@@ -267,7 +267,7 @@ struct TensorCoordinate
private:
template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
{
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
@@ -275,7 +275,7 @@ struct TensorCoordinate

template <typename... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
{
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());

@@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated
private:
template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
{
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
}

template <class... Ts>
__host__ __device__ static constexpr auto
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
{
return MergedTensorCoordinate_deprecated<
ConstantMergedTensorDescriptor_deprecated<Ts...>>();

@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
index_t... LowerDimensionIds,
index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
Sequence<LowerLengths...>,
Sequence<LowerDimensionIds...>,
Sequence<UpperDimensionIds...>)
{
return TransformedTensorDescriptor<LowerTensorDescriptor,
Tuple<PassThrough<LowerLengths>...>,
@@ -78,7 +78,7 @@ reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
@@ -96,7 +96,7 @@ reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLo
// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
static_assert(is_valid_sequence_map<MapLower2Upper>{},
"wrong! MapLower2Upper is not a valid map");
@@ -152,9 +152,9 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};

// sanity-checknfoldable
// sanity-check if unfold-able
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
"wrong! not unfoldable");
"wrong! not unfold-able");

// unfolded length, stride
constexpr index_t unfold_length =

@@ -23,8 +23,8 @@ template <index_t BlockSize,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
@@ -150,13 +150,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
DataPerReadA>{};
ThreadGemmADataPerRead_M>{};

constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
DataPerReadB>{};
ThreadGemmBDataPerRead_N>{};

constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
@@ -238,13 +238,13 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
DataPerReadA>{};
ThreadGemmADataPerRead_M>{};

constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
DataPerReadB>{};
ThreadGemmBDataPerRead_N>{};

constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),

@@ -9,7 +9,7 @@

namespace ck {

// This threadwise copy allow vector access of src and dst.
// This blockwise copy allow vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
@@ -28,10 +28,10 @@ template <index_t BlockSize,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct BlockwiseGenericTensorSliceCopy_v4
{
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
@@ -115,7 +115,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr,
static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread "
"buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
@@ -157,7 +157,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
1,
SrcAddressSpace,
ThreadBufferAddressSpace,
InMemoryDataOperation::none>;
InMemoryDataOperation::Set>;

using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc,

@@ -499,7 +499,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
ThreadBufferData* p_thread_buffer) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};

RunLoadThreadBuffer(
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
@@ -529,7 +529,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};

RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
@@ -548,7 +548,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
BlockSrcData p_thread_buffer[GetThreadBufferSize()];

constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};

RunLoadThreadBuffer(
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
@@ -562,7 +562,7 @@ struct BlockwiseGenericTensorSliceCopy_v2_deprecated
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};

Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
}

@@ -22,15 +22,15 @@ template <index_t GridSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN,
index_t ThreadGemmAThreadCopySrcDataPerRead_M,
index_t ThreadGemmBThreadCopySrcDataPerRead_N,
typename ABlockCopyThreadSliceLengths_K_M,
typename ABlockCopyThreadClusterLengths_K_M,
typename ABlockCopyThreadClusterArrangeOrder,
@@ -54,8 +54,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
@@ -101,8 +101,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);

// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
@@ -139,10 +139,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, m_block_data_on_global}, {0, 0});

// B matrix in LDS memory, dst of blockwise copy
@@ -165,10 +165,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>(
{0, n_block_data_on_global}, {0, 0});

// GEMM definition
@@ -181,35 +181,33 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);

// sanity check
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");

constexpr index_t GemmMRepeat =
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);

constexpr index_t GemmNRepeat =
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThreadSubC,
NPerThreadSubC,
MPerThread,
NPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThreadLoop,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN>{};
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};

// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
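The MPerThreadSubC -> MPerThread rename above leaves the repeat arithmetic untouched: a block is tiled by the level-0/level-1 thread clusters, and each thread revisits the block GemmMRepeat x GemmNRepeat times. A hedged restatement with the 128/4/4/4 values used by the device-side configs later in this commit:

// Illustrative values taken from the tuning configs below; not a library definition.
constexpr int MPerBlock = 128, MPerThread = 4, MLevel0Cluster = 4, MLevel1Cluster = 4;
constexpr int GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
static_assert(GemmMRepeat == 2, "128 / (4 * 4 * 4) == 2 repeats per thread in M");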
@@ -233,6 +231,9 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}

constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};

// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
@@ -255,8 +256,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

__syncthreads();

@@ -282,8 +283,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

__syncthreads();

@@ -317,16 +318,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1

// input: register to global memory
{
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;

constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;

// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
@@ -352,8 +353,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::global,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,

@@ -21,9 +21,9 @@ template <typename SrcDesc,
index_t SrcDstVectorReadWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
AddressSpace SrcAddressSpace = AddressSpace::Generic,
AddressSpace DstAddressSpace = AddressSpace::Generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
struct ThreadwiseGenericTensorSliceCopy_v4r2
{
static constexpr index_t nDim = SliceLengths::Size();
@@ -115,8 +115,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
@@ -146,7 +146,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
@@ -265,12 +265,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
}
}

@@ -303,7 +303,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
@@ -404,8 +404,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
AddressSpace::Vgpr,
InMemoryDataOperation::Set>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
@@ -448,7 +448,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
buffer_offset,

@@ -333,7 +333,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. src_normal_offset must be calculatd at compile time (guaranteed by
// algorithm)
// 3. src_merged_offset can be runtime value (no assumption imposed)
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) {
static_if<SrcAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
fwd(p_src), src_merged_offset, src_normal_offset);
@@ -442,7 +442,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
// 2. dst_normal_offset must be calculatd at compile time (guaranteed by
// algorithm)
// 3. dst_merged_offset can be runtime value (no assumption imposed)
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) {
static_if<DstAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
#if CK_USE_AMD_BUFFER_ADDRESSING
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
@@ -464,7 +464,7 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
integral_constant<AddressSpace, AddressSpace::Generic>{};

Run(p_src, p_dst, generic_address_space, generic_address_space);
}

@@ -54,21 +54,23 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0

namespace ck {

enum AddressSpace
{
generic,
global,
lds,
vgpr
Generic,
Global,
Lds,
Vgpr
};

enum InMemoryDataOperation
{
none,
atomic_add
Set,
AtomicAdd
};

#if CK_UNSIGNED_INDEX_TYPE

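Note these are unscoped enums, so the old lowercase enumerators leaked into namespace ck as bare identifiers; that is how call sites such as DstAddressSpace == vgpr (corrected in the next hunk) compiled at all. The CamelCase rename keeps the leak but makes collisions with common identifiers less likely. A standalone sketch of the distinction, not the library's definitions:

enum AddressSpace { Generic, Global, Lds, Vgpr };             // unscoped: Global is usable unqualified
enum class ScopedAddressSpace { Generic, Global, Lds, Vgpr }; // scoped alternative: qualification required

static_assert(Global == AddressSpace::Global, "unscoped enumerators leak into the enclosing scope");
// ScopedAddressSpace::Global would be the only way to name the scoped variant.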
@@ -10,13 +10,14 @@ template <typename T,
index_t DataPerAccess,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace>
__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
__device__ void set_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

#if CK_USE_AMD_BUFFER_ADDRESSING
// TODO: use static_if::ElseIf, instead of nested static_if
static_if<SrcAddressSpace == AddressSpace::global && DstAddressSpace == vgpr>{}([&](auto) {
static_if<SrcAddressSpace == AddressSpace::Global &&
DstAddressSpace == AddressSpace::Vgpr>{}([&](auto) {
// buffer_load requires:
// 1) p_src must be in global memory space, d_dst must be vgpr
// 2) p_src to be a block-invariant pointer.
@@ -24,7 +25,8 @@ __device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
}).Else([&](auto) {
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == global>{}([&](auto) {
static_if<SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global>{}([&](auto) {
// buffer_store requires:
// 1) p_src must be in vgpr space, d_dst must be global memory
// 2) p_dst to be a block-invariant pointer.
@@ -50,19 +52,18 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
static_if<SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global>{}([&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
#else
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
#endif
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}

template <typename T,
@@ -72,17 +73,17 @@ template <typename T,
InMemoryDataOperation DstInMemOp>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::none ||
DstInMemOp == InMemoryDataOperation::atomic_add,
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");

// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
set_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});

static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});

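transfer_data selects set_data or atomic_add_data entirely at compile time through static_if, presumably because if constexpr was not available to the targeted device compilers. The same dispatch in C++17 form, as a standalone host-side sketch of the pattern rather than a proposed change to the library:

#include <cassert>

enum class InMemOp { Set, AtomicAdd };

template <InMemOp Op>
void transfer(const float* src, float* dst)
{
    if constexpr(Op == InMemOp::Set)
        *dst = *src;  // set_data path: overwrite the destination
    else
        *dst += *src; // atomic_add_data path: accumulate (plain add for illustration)
}

int main()
{
    float src = 2.0f, dst = 1.0f;
    transfer<InMemOp::Set>(&src, &dst);       // dst == 2
    transfer<InMemOp::AtomicAdd>(&src, &dst); // dst == 4
    assert(dst == 4.0f);
    return 0;
}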
@@ -23,14 +23,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
static_if<SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global>{}([&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}

template <typename T,
@@ -40,17 +39,17 @@ template <typename T,
InMemoryDataOperation DstInMemOp>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::none ||
DstInMemOp == InMemoryDataOperation::atomic_add,
static_assert(DstInMemOp == InMemoryDataOperation::Set ||
DstInMemOp == InMemoryDataOperation::AtomicAdd,
"wrong! InMemoryDataOperation not supported!");

// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});

static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, p_dst, dst_offset);
});

@@ -107,27 +107,22 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
template <typename T>
__host__ __device__ constexpr T gcd(T x, T y)
{
if(x == 0)
if(x == y || x == 0)
{
return y;
}

if(y == 0)
else if(y == 0)
{
return x;
}

if(x == y)
{
return x;
}

if(x > y)
else if(x > y)
{
return gcd(x - y, y);
}

return gcd(x, y - x);
else
{
return gcd(x, y - x);
}
}

template <index_t X, index_t Y>

@@ -150,10 +145,10 @@ __host__ __device__ constexpr T lcm(T x, T y)
return (x * y) / gcd(x, y);
}

template <typename X, typename Y, typename... Zs>
__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
template <typename X, typename... Ys>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
return lcm(x, lcm(y, zs...));
return lcm(x, lcm(ys...));
}

template <class T>

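The restructured gcd is still subtraction-based Euclid, now as a single if/else-if chain, and the variadic lcm folds uniformly instead of special-casing the first two arguments. Tracing gcd(8, 12): x < y gives gcd(8, 4), then x > y gives gcd(4, 4), then x == y returns 4; and lcm(2, 3, 4) = lcm(2, lcm(3, 4)) = lcm(2, 12) = 12. A compile-time check, assuming the surrounding ck math namespace is in scope:

static_assert(math::gcd(8, 12) == 4, "gcd(8,12) -> gcd(8,4) -> gcd(4,4) -> 4");
static_assert(math::lcm(8, 12) == 24, "8 * 12 / gcd(8, 12)");
static_assert(math::lcm(2, 3, 4) == 12, "variadic: lcm(2, lcm(3, 4))");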
@@ -49,20 +49,20 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
#if 0
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;

constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -79,6 +79,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;

constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;

constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;

constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif

constexpr index_t GemmM = C * Y * X;
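The comment "BlockSize = 256, each thread hold 64 data" follows directly from the tuning values above: the two cluster levels tile 4 * 4 * 4 * 4 = 256 threads, and each thread accumulates (128 * 128) / 256 = 64 elements of C. Restated as hedged compile-time checks over those constants (placed here for illustration, not part of the source):

static_assert(GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                      GemmNLevel1Cluster ==
                  BlockSize,
              "thread cluster covers the whole block: 4 * 4 * 4 * 4 == 256");
static_assert(GemmMPerBlock * GemmNPerBlock / BlockSize == 64,
              "each thread holds 64 C elements");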
@@ -104,13 +134,13 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,

@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -96,13 +96,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -127,13 +127,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -152,33 +152,33 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

constexpr index_t GemmM = C * Ytilda * Xtilda;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GemmM = C * YTilda * XTilda;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
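The Trim -> Slice rename does not change the arithmetic. A worked example with illustrative shapes (not taken from the source): Hi = 28, Ho = 14, Y = 3, left pad 1, ConvStrideH = 2, ConvDilationH = 1 gives

// YTilda      = 2 / gcd(2, 1)                                  = 2
// HTilda      = 14 + integer_divide_ceil(1 * (3 - 1), 2)       = 15
// HTildaLeft  = integer_divide_floor(max(0, 1 - 1 * (2 - 1)), 2) = 0
// HTildaRight = min(15, integer_divide_ceil(1 + 28 - 1, 2) + 1)  = 15
// HTildaSlice = 15 - 0                                         = 15
// With N = 1 and symmetric values in W, GemmN = 1 * 15 * 15 = 225.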
@@ -200,13 +200,13 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,

@@ -66,13 +66,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -91,33 +91,33 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
@@ -139,13 +139,13 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,

@@ -69,13 +69,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -99,13 +99,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;

@@ -124,33 +124,33 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
constexpr index_t HTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
constexpr index_t WTildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
constexpr index_t HTildaRight = math::min(
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WTildaRight = math::min(
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

constexpr index_t GemmM = C;
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
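The renamed quantities above are easiest to sanity-check with concrete numbers. Below is a small standalone sketch (C++14, since std::max/std::min are used in constant expressions) that evaluates them for the 3x3-filter, 2x2-stride, 35x35-input case enabled in the driver further down, assuming dilation 1 and zero padding so that Ho = 17; integer_divide_ceil/floor are taken to be plain ceil/floor division, matching the math:: helpers.

#include <algorithm>
#include <cstdio>

using index_t = int;
constexpr index_t idiv_ceil(index_t a, index_t b) { return (a + b - 1) / b; }
constexpr index_t idiv_floor(index_t a, index_t b) { return a / b; } // fine for a >= 0

int main()
{
    // assumed config: 3x3 filter, 2x2 stride, dilation 1, zero pad, 35x35 input
    constexpr index_t Hi = 35, Ho = 17, Y = 3;
    constexpr index_t ConvStrideH = 2, ConvDilationH = 1, LeftPadH = 0;

    constexpr index_t GcdStrideDilationH = 1;                           // gcd(2, 1)
    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;        // 2
    constexpr index_t YDot   = idiv_ceil(Y, YTilda);                    // 2
    constexpr index_t HTilda =
        Ho + idiv_ceil(ConvDilationH * (Y - 1), ConvStrideH);           // 17 + 1 = 18
    constexpr index_t HTildaLeft =
        idiv_floor(std::max(0, LeftPadH - ConvDilationH * (YTilda - 1)), ConvStrideH); // 0
    constexpr index_t HTildaRight =
        std::min(HTilda, idiv_ceil(LeftPadH + Hi - 1, ConvStrideH) + 1); // min(18, 18) = 18
    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;            // 18

    std::printf("YTilda=%d YDot=%d HTildaSlice=%d\n", YTilda, YDot, HTildaSlice);
}

With the same arithmetic along W, GemmN = N * HTildaSlice * WTildaSlice = 128 * 18 * 18 = 41472, which is the value assumed in the grid-size sketch earlier.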
@@ -159,7 +159,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i

for(index_t i = 0; i < nrepeat; ++i)
{
using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
@@ -174,13 +174,13 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -196,21 +196,29 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
KernelTimer timer;
timer.Start();

static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
constexpr index_t gemm_id = decltype(gemm_id_){};

launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
gemm_id,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
constexpr index_t gemm_k = gemm_sizes.At(2);
constexpr bool is_gemm_not_empty = gemm_k > 0;

// only compile and run if GEMM is not empty
static_if<is_gemm_not_empty>{}([&](auto fwd) {
launch_kernel(
run_gridwise_convolution_backward_data_v4r1<GridwiseConvBwdData,
fwd(gemm_id),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
});
});

timer.End();

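The guard introduced in this hunk matters because v4r1 backward-data splits the convolution into one sub-GEMM per (YTilda, XTilda) residue, and for some filter/stride combinations a sub-GEMM ends up with a zero-size K dimension that must not be instantiated at all. The sketch below shows the same skip-empty-GEMM idea using C++17 if constexpr; it is an analogue, not the repo's static_if utility, and GetGemmK is a hypothetical stand-in for GridwiseConvBwdData::GetGemmSize(gemm_id).At(2).

#include <cstdio>

// hypothetical stand-in: K extent of sub-GEMM GemmId (assumed sizes for demo)
template <int GemmId>
constexpr int GetGemmK() { return GemmId % 2 == 0 ? 8 : 0; }

template <int GemmId>
void launch_if_not_empty()
{
    if constexpr(GetGemmK<GemmId>() > 0)
        std::printf("launch sub-GEMM %d (K = %d)\n", GemmId, GetGemmK<GemmId>());
    // an empty sub-GEMM is never instantiated, so no zero-size kernel is compiled
}

int main()
{
    launch_if_not_empty<0>();
    launch_if_not_empty<1>(); // empty: discarded at compile time
}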
@@ -23,17 +23,16 @@ int main(int argc, char* argv[])
{
using namespace launcher;

#if 0
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#if 1
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<2, 2>;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
@@ -158,7 +157,7 @@ int main(int argc, char* argv[])

using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
@@ -188,7 +187,7 @@ int main(int argc, char* argv[])

using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 1024;