mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Added bwd data v3r1 v4r1, tweaking v1 (#10)
* Added bwd data v3r1: breaking down compute into a series of load balanced GEMM, and launch in a single kernel
* Added bwd data v4r1: like v3r1, but launch GEMMs in multiple kernels
* Tweaked v1r1 and v1r2 (atomic) on AMD GPU
[ROCm/composable_kernel commit: c5da0377fb]
This commit is contained in:
@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
|
||||
Sequence<Y, Ho>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
|
||||
Sequence<X, Wo>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -122,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM: atomic add
|
||||
// GEMM
|
||||
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
|
||||
? InMemoryDataOperation::none
|
||||
: InMemoryDataOperation::atomic_add;
|
||||
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
@@ -131,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
decltype(wei_k_e_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
InMemoryDataOperation::atomic_add,
|
||||
in_memory_op,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
|
||||
@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
}
|
||||
}
|
||||
|
||||
// input: register to global memory, atomic add
|
||||
{
|
||||
#if 1 // debug
|
||||
// input: register to global memory, atomic add
|
||||
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
|
||||
? InMemoryDataOperation::none
|
||||
: InMemoryDataOperation::atomic_add;
|
||||
#else
|
||||
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
|
||||
#endif
|
||||
|
||||
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t E0 = E / E1;
|
||||
|
||||
@@ -378,8 +386,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1>>{},
|
||||
UnMerge<Sequence<C0, C1>>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hi + LeftPads::At(0) + RightPads::At(0),
|
||||
Sequence<Y, Ho>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
|
||||
Sequence<X, Wo>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
@@ -422,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
InThreadCopyDstDataPerWrite_B,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
|
||||
{e_thread_data_on_global / E1,
|
||||
e_thread_data_on_global % E1,
|
||||
0,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1,
|
||||
0})
|
||||
in_memory_op>({0, 0, 0, 0, 0, 0},
|
||||
{e_thread_data_on_global / E1,
|
||||
e_thread_data_on_global % E1,
|
||||
0,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1,
|
||||
0})
|
||||
.Run(p_in_thread, p_in_global);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
namespace ck {
|
||||
|
||||
// GemmM = C * Ytilda * Xtilda;
|
||||
// GemmN = N * Htilda * Wtilda;
|
||||
// GemmN = N * HtildaNonZero * WtildaNonZero;
|
||||
// GemmK = K * Ydot * Xdot;
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
@@ -20,8 +20,8 @@ template <index_t GridSize,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InputLeftPads,
|
||||
typename InputRightPads,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0 // debug
|
||||
// sanity-check for vectorized memory load
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
@@ -88,30 +90,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
|
||||
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
|
||||
constexpr index_t Htilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Y, X>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_yp_xp_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ydot, Ytilda>,
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Xtilda>,
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -123,55 +129,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_hop_wop_global_desc,
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Sequence<Ydot, Htilda>,
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Wtilda>,
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr auto gridwise_gemm =
|
||||
|
||||
@@ -0,0 +1,393 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Ytilda*Xtilda number of GEMMs
|
||||
// GemmM = C;
|
||||
// GemmN = N * HtildaNonZero * WtildaNonZero;
|
||||
// GemmK = K * YdotNonZero * XdotNonZero;
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
{
|
||||
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space =
|
||||
math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space + b_block_space) * sizeof(Float);
|
||||
}
|
||||
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0 // debug
|
||||
// sanity-check for vectorized memory load
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr bool wei_skip_all_out_of_bound_check = true;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool out_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool out_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMMs
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
|
||||
|
||||
__shared__ Float p_shared_block[shared_block_size];
|
||||
|
||||
#if 1 // debug
|
||||
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
|
||||
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
|
||||
#else
|
||||
static_for<0, 1, 1>{}([&](auto ytilda_) {
|
||||
static_for<0, 1, 1>{}([&](auto xtilda_) {
|
||||
#endif
|
||||
constexpr index_t ytilda = decltype(ytilda_){};
|
||||
constexpr index_t xtilda = decltype(xtilda_){};
|
||||
|
||||
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
|
||||
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
|
||||
|
||||
// A matrix
|
||||
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<C, 1, 1>>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B matrix
|
||||
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C matrix
|
||||
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, 1, 1>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);
|
||||
|
||||
// is synchronization necessary?
|
||||
__syncthreads();
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,333 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = C
|
||||
// GemmN = N * Htilda * Wtilda;
|
||||
// GemmK = K * YdotNonZero * XdotNonZero
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t Iter_ytilda,
|
||||
index_t Iter_xtilda,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0 // debug
|
||||
// sanity-check for vectorized memory load
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr bool wei_skip_all_out_of_bound_check = true;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool out_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool out_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr index_t ytilda = Iter_ytilda;
|
||||
constexpr index_t xtilda = Iter_xtilda;
|
||||
|
||||
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
|
||||
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
|
||||
|
||||
// A matrix
|
||||
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B matrix
|
||||
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C matrix
|
||||
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -181,12 +181,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
|
||||
@@ -97,12 +97,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
|
||||
@@ -41,15 +41,17 @@ struct PassThrough
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths, typename LeftPads, typename RightPads>
|
||||
template <typename LowerLengths,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
bool SkipIsValidCheck = false>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowerLengths::Size();
|
||||
@@ -57,6 +59,13 @@ struct Pad
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ explicit constexpr Pad()
|
||||
{
|
||||
static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
|
||||
RightPads::GetSize() == nDim,
|
||||
"wrong! # of dimensions not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
|
||||
@@ -81,20 +90,83 @@ struct Pad
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
#if 1 // debug
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
|
||||
(idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
|
||||
});
|
||||
for(index_t i = 0; i < nDim; ++i)
|
||||
{
|
||||
flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
};
|
||||
|
||||
// LowerLengths: Sequence<...>
|
||||
// SliceBegins: Sequence<...>
|
||||
// SliceEnds: Sequence<...>
|
||||
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
|
||||
struct Slice
|
||||
{
|
||||
static constexpr index_t nDim = LowerLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ explicit constexpr Slice()
|
||||
{
|
||||
static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
|
||||
SliceEnds::GetSize() == nDim,
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
#if 0
|
||||
// TODO: would not compile, error on constexpr
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
|
||||
SliceBegins::At(idim) >= 0 &&
|
||||
SliceEnds::At(idim) <= LowerLengths::At(idim),
|
||||
"wrong! Slice config is wrong");
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return SliceEnds{} - SliceBegins{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
return idx_up + SliceBegins{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const UpperIndex& /* idx_up_old */,
|
||||
const LowerIndex& /* idx_low_old */)
|
||||
{
|
||||
return idx_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths>
|
||||
struct Merge
|
||||
@@ -165,85 +237,101 @@ struct Merge
|
||||
const UpperIndex& /* idx_up_old */,
|
||||
const LowerIndex& idx_low_old)
|
||||
{
|
||||
// do nothing if idx_up_diff == 0
|
||||
if(idx_up_diff[0] == 0)
|
||||
{
|
||||
return make_zero_array<index_t, nDimLow>();
|
||||
}
|
||||
|
||||
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
|
||||
// If idx_up_diff is known at compile-time, the calculation can
|
||||
// be done at compile-time. However, if idx_up_diff is only known
|
||||
// at run-time, then the calculation will also be computed at
|
||||
// run-time, and can be very expensive.
|
||||
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
|
||||
|
||||
if(idx_up_diff[0] > 0)
|
||||
else
|
||||
{
|
||||
bool carry = false;
|
||||
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
|
||||
// If idx_up_diff is known at compile-time, the calculation can
|
||||
// be done at compile-time. However, if idx_up_diff is only known
|
||||
// at run-time, then the calculation will also be computed at
|
||||
// run-time, and can be very expensive.
|
||||
LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
// find out the last low dimension that changed
|
||||
index_t last_changed_low_dim = 0;
|
||||
|
||||
static_for<0, nDimLow, 1>{}([&](auto i) {
|
||||
if(idx_low_diff_tmp[i] != 0)
|
||||
{
|
||||
last_changed_low_dim = i;
|
||||
}
|
||||
});
|
||||
|
||||
LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
|
||||
|
||||
if(idx_up_diff[0] > 0)
|
||||
{
|
||||
// do carry check on each low dimension in reversed order
|
||||
// starting from the first digit that changed
|
||||
// don't check the highest dimension
|
||||
bool carry = false;
|
||||
|
||||
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
|
||||
if(i <= last_changed_low_dim)
|
||||
{
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(i);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(idx_low_new[i] >= LowerLengths::At(i))
|
||||
{
|
||||
idx_low_new(i) -= LowerLengths::At(i);
|
||||
carry = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(i);
|
||||
++idx_low_new(0);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(idx_low_new[i] >= LowerLengths::At(i))
|
||||
{
|
||||
idx_low_new(i) -= LowerLengths::At(i);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(0);
|
||||
}
|
||||
}
|
||||
else if(idx_up_diff[0] < 0)
|
||||
{
|
||||
bool borrow = false;
|
||||
else
|
||||
{
|
||||
// do borrow check on each low dimension in reversed order
|
||||
// starting from the first digit that changed
|
||||
// don't check the highest dimension
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
|
||||
if(i <= last_changed_low_dim)
|
||||
{
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(i);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(idx_low_new[i] < 0)
|
||||
{
|
||||
idx_low_new(i) += LowerLengths::At(i);
|
||||
borrow = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(i);
|
||||
--idx_low_new(0);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(idx_low_new[i] < 0)
|
||||
{
|
||||
idx_low_new(i) += LowerLengths::At(i);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(0);
|
||||
}
|
||||
}
|
||||
|
||||
return idx_low_new - idx_low_old;
|
||||
return idx_low_new - idx_low_old;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -290,8 +378,7 @@ struct UnMerge
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -300,7 +387,10 @@ struct UnMerge
|
||||
// UpperLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
template <typename UpperLengths, typename Coefficients>
|
||||
template <index_t LowerLength,
|
||||
typename UpperLengths,
|
||||
typename Coefficients,
|
||||
bool SkipIsValidCheck = false>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
@@ -325,8 +415,10 @@ struct Embed
|
||||
{
|
||||
LowerIndex idx_low(Coefficients{}[nDimUp]);
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
|
||||
for(index_t i = 0; i < nDimUp; ++i)
|
||||
{
|
||||
idx_low(0) += idx_up[i] * Coefficients{}[i];
|
||||
}
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
@@ -338,18 +430,55 @@ struct Embed
|
||||
{
|
||||
LowerIndex idx_low_diff{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
|
||||
for(index_t i = 0; i < nDimUp; ++i)
|
||||
{
|
||||
idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
|
||||
}
|
||||
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
#if 1 // debug
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
bool flag = true;
|
||||
|
||||
index_t ncorner = 1;
|
||||
|
||||
for(index_t idim = 0; idim < nDimUp; ++idim)
|
||||
{
|
||||
ncorner *= 2;
|
||||
}
|
||||
|
||||
// loop over each corner of the upper tensor
|
||||
for(index_t icorner = 0; icorner < ncorner; ++icorner)
|
||||
{
|
||||
// generate upper index for each corner
|
||||
auto idx_up = make_zero_array<index_t, nDimUp>();
|
||||
|
||||
index_t itmp = icorner;
|
||||
|
||||
for(index_t idim = nDimUp - 1; idim >= 0; --idim)
|
||||
{
|
||||
idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
|
||||
itmp /= 2;
|
||||
}
|
||||
|
||||
// calculate lower index
|
||||
auto idx_low = CalculateLowerIndex(idx_up);
|
||||
|
||||
// judge if lower index is valid
|
||||
flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -389,8 +518,7 @@ struct Vectorize
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -53,6 +53,8 @@ struct NativeTensorCoordinate
|
||||
|
||||
__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }
|
||||
|
||||
__host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }
|
||||
|
||||
__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }
|
||||
|
||||
__host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
|
||||
@@ -98,7 +100,24 @@ struct NativeTensorCoordinate
|
||||
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
|
||||
// evaluated at run-time
|
||||
__host__ __device__ constexpr bool IsUpperIndexValid() const
|
||||
{
|
||||
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
|
||||
}
|
||||
|
||||
// evaluated at run-time
|
||||
__host__ __device__ constexpr bool IsOffsetValid() const
|
||||
{
|
||||
// For native tensor, offset is valid if upper-index is valid
|
||||
return IsUpperIndexValid();
|
||||
}
|
||||
|
||||
// evaluated at compile-time
|
||||
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
|
||||
@@ -206,10 +225,30 @@ struct TransformedTensorCoordinate
|
||||
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
|
||||
// evaluated at run-time
|
||||
__host__ __device__ constexpr bool IsUpperIndexValid() const
|
||||
{
|
||||
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
|
||||
mCoordLow.IsUpperIndexMappedToValidOffset();
|
||||
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
|
||||
}
|
||||
|
||||
// evaluted at run-time
|
||||
__host__ __device__ constexpr bool IsOffsetValid() const
|
||||
{
|
||||
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
|
||||
}
|
||||
|
||||
// most evaluatation is done at comile-time
|
||||
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
|
||||
{
|
||||
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
|
||||
GetLowerCoordinate().GetIndex());
|
||||
}
|
||||
|
||||
// most evaluatation is done at comile-time
|
||||
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
|
||||
{
|
||||
return IsLowerIndexValidAssumingUpperIndexIsValid() &&
|
||||
GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -120,10 +120,17 @@ struct NativeTensorDescriptor
|
||||
return Tuple<>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
|
||||
// a multi-index is valid if there is a corresponding point for it in the tensor
|
||||
__host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
|
||||
{
|
||||
return true;
|
||||
bool flag = true;
|
||||
|
||||
for(index_t i = 0; i < nDim; ++i)
|
||||
{
|
||||
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -467,33 +474,49 @@ struct TransformedTensorDescriptor
|
||||
}
|
||||
#endif
|
||||
|
||||
// a multi-index is valid if there is a corresponding point for it in the tensor
|
||||
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
for(index_t i = 0; i < nDimUp; ++i)
|
||||
{
|
||||
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
// this function is for optimization purpose, it's called by tensor coordinate
|
||||
// this function tells you: If a lower-index is valid or not, assuming upper index is valid
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
|
||||
IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, nTransform, 1>{}([&](auto itran) {
|
||||
constexpr auto tran = Transforms{}.At(itran);
|
||||
|
||||
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
|
||||
// check a indtransformation if it does not always has a valid mapping
|
||||
constexpr bool is_valid_up_always_mapped_to_valid_low =
|
||||
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();
|
||||
|
||||
flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
|
||||
if(!is_valid_up_always_mapped_to_valid_low)
|
||||
{
|
||||
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
|
||||
constexpr auto low_lengths_part =
|
||||
GetLowerTensorDescriptor().GetLengths(low_dims_part);
|
||||
const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part));
|
||||
|
||||
for(index_t i = 0; i < low_dims_part.Size(); ++i)
|
||||
{
|
||||
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
// Whenever this function is called, it will call CalculateLowerIndex() recursively.
|
||||
// If you have created a tensor coordinate already, instead of calling this function,
|
||||
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
|
||||
// be less expensive.
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
|
||||
{
|
||||
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
|
||||
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
|
||||
CalculateLowerIndex(idx_up));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -50,9 +50,37 @@ template <index_t GridSize,
|
||||
index_t CThreadCopyDstDataPerWrite>
|
||||
struct GridwiseGemmTransposedANormalBNormalC_v1
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space + b_block_space) * sizeof(Float);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global) const
|
||||
Float* __restrict__ p_c_global,
|
||||
Float* __restrict__ p_shared_block) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
@@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
|
||||
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
|
||||
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
|
||||
|
||||
// don't do anything if K == 0
|
||||
if(K == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// lds max alignment
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
@@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
__shared__ Float p_a_block_double[2 * a_block_space];
|
||||
__shared__ Float p_b_block_double[2 * b_block_space];
|
||||
Float* p_a_block_double = p_shared_block;
|
||||
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
|
||||
|
||||
// register allocation for output
|
||||
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
|
||||
@@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
|
||||
.Run(p_c_thread, p_c_global);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global) const
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
|
||||
|
||||
__shared__ Float p_shared_block[shared_block_size];
|
||||
|
||||
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Check src data's valid mapping situation, only check the first data in this src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Check dst data's valid mapping situation, only check the first data in this dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// src
|
||||
// vector. It's user's responsiblity to make sure all data in the src vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// dst
|
||||
// vector. It's user's responsiblity to make sure all data in the dst vector
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
|
||||
@@ -54,6 +54,13 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_buffer_atomic_add(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src must be in global memory space, d_dst must be vgpr
|
||||
// 2) p_src to be a block-invariant pointer.
|
||||
@@ -73,6 +80,13 @@ amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void
|
||||
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
|
||||
template <>
|
||||
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
@@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_atomic_add(
|
||||
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
|
||||
#else
|
||||
static_assert(false, " wrong! not implemented");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -29,6 +29,11 @@
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
|
||||
#endif
|
||||
|
||||
// only support gfx908
|
||||
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0
|
||||
#endif
|
||||
|
||||
// AMD XDLOPS
|
||||
#ifndef CK_USE_AMD_XDLOPS
|
||||
#define CK_USE_AMD_XDLOPS 0
|
||||
|
||||
@@ -29,9 +29,10 @@ struct static_for
|
||||
{
|
||||
__host__ __device__ constexpr static_for()
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd");
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should have NBegin <= NEnd");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
|
||||
@@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
|
||||
|
||||
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
|
||||
[&](auto) {
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
|
||||
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
|
||||
#else
|
||||
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
|
||||
#endif
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(fwd(false), "atomic_add doesn't support this memory space");
|
||||
|
||||
@@ -49,6 +49,12 @@ struct integer_divide_ceiler
|
||||
}
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
|
||||
{
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <class X, class Y>
|
||||
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
|
||||
{
|
||||
|
||||
@@ -24,9 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
|
||||
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu)
|
||||
endif()
|
||||
|
||||
add_executable(conv ${CONV_SOURCE})
|
||||
add_executable(col2im ${COL2IM_SOURCE})
|
||||
add_executable(conv_bwd_data ${CONV_BWD_DATA_SOURCE})
|
||||
target_link_libraries(conv PRIVATE host)
|
||||
target_link_libraries(col2im PRIVATE host)
|
||||
target_link_libraries(conv_bwd_data PRIVATE host)
|
||||
add_executable(conv_driver ${CONV_SOURCE})
|
||||
add_executable(col2im_driver ${COL2IM_SOURCE})
|
||||
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
|
||||
target_link_libraries(conv_driver PRIVATE host)
|
||||
target_link_libraries(col2im_driver PRIVATE host)
|
||||
target_link_libraries(conv_bwd_data_driver PRIVATE host)
|
||||
|
||||
@@ -30,33 +30,81 @@ struct KernelTimer
|
||||
std::unique_ptr<KernelTimerImpl> impl;
|
||||
};
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
using device_stream_t = hipStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
timer.Start();
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...);
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
|
||||
timer.End();
|
||||
|
||||
hipGetErrorString(hipGetLastError());
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
using device_stream_t = cudaStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
timer.Start();
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0);
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
|
||||
timer.End();
|
||||
|
||||
checkCudaErrors(error);
|
||||
#endif
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -121,20 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -145,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -129,20 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -153,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -27,12 +31,16 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
@@ -81,6 +89,67 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
// for 1x1 weight, 8x8 input
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
@@ -92,14 +161,24 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
|
||||
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * Htilda * Wtilda;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
@@ -142,20 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -166,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -0,0 +1,232 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
timer.Start();
|
||||
|
||||
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
|
||||
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
|
||||
constexpr index_t ytilda = decltype(ytilda_){};
|
||||
constexpr index_t xtilda = decltype(xtilda_){};
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ytilda,
|
||||
xtilda,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
});
|
||||
});
|
||||
|
||||
timer.End();
|
||||
|
||||
float time = timer.GetElapsedTime();
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -82,13 +82,13 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
|
||||
WoPerThread,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>;
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
|
||||
@@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
OutThreadCopyDataPerWrite_W>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
|
||||
WeiBlockCopyDataPerRead,
|
||||
OutThreadCopyDataPerWrite>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
WeiBlockCopyDataPerAccess_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// BlockSize = 64, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -258,13 +258,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
|
||||
@@ -81,6 +81,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
@@ -247,10 +284,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -83,6 +83,37 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
@@ -116,7 +147,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
@@ -225,10 +256,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
float time = launch_and_time_kernel(
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
|
||||
TOut,
|
||||
accum_t,
|
||||
|
||||
@@ -16,22 +16,25 @@
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace launcher;
|
||||
|
||||
#if 1
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 16;
|
||||
constexpr index_t WI = 16;
|
||||
constexpr index_t K = 8;
|
||||
constexpr index_t Y = 2;
|
||||
constexpr index_t X = 2;
|
||||
#if 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<4, 4>;
|
||||
using ConvDilations = Sequence<2, 2>;
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
@@ -41,7 +44,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -51,27 +54,27 @@ int main(int argc, char* argv[])
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
@@ -83,25 +86,10 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1280;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
@@ -123,27 +111,12 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 128;
|
||||
@@ -153,105 +126,30 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 768;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 528;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 528;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 5x5 filter, 2x2 pad, 7x7 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 48;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
|
||||
@@ -260,28 +158,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<2, 2>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
// 7x1 filter, 3x0 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
@@ -290,6 +173,36 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
#elif 0
|
||||
// 7x1 filter, 3x0 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#endif
|
||||
|
||||
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
@@ -337,8 +250,12 @@ int main(int argc, char* argv[])
|
||||
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
|
||||
#else
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
|
||||
#endif
|
||||
(in_nchw_desc,
|
||||
in_nchw_device,
|
||||
|
||||
@@ -30,32 +30,63 @@ int main(int argc, char* argv[])
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
// 1x1
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
@@ -282,21 +313,6 @@ int main(int argc, char* argv[])
|
||||
using LeftPads = Sequence<2, 2>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
// 7x1 filter, 3x0 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -311,6 +327,21 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
#elif 1
|
||||
// 7x1 filter, 3x0 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#endif
|
||||
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
|
||||
@@ -14,13 +14,11 @@ cmake
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D DEVICE_BACKEND=NVIDIA \
|
||||
-D CUDA_COMMON_INCLUDE_DIR="/package/install/cuda/10.1/NVIDIA_CUDA-10.1_Samples/common/inc" \
|
||||
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
|
||||
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
#-D BOOST_ROOT="/package/install/boost_1.67.0" \
|
||||
|
||||
#-D CMAKE_CUDA_COMPILER="/package/install/cuda_10.0/bin/nvcc" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
|
||||
|
||||
@@ -15,11 +15,9 @@ cmake
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D DEVICE_BACKEND=NVIDIA \
|
||||
-D CUDA_COMMON_INCLUDE_DIR="/root/NVIDIA_CUDA-10.1_Samples/common/inc" \
|
||||
-D CMAKE_CUDA_FLAGS="-ccbin clang++-6.0 -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_60,code=sm_60 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
|
||||
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -Xptxas -v -maxrregcount=128" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
cuobjdump -xelf sm_60 ./driver/driver && nvdisasm --print-code -g driver.sm_60.cubin > driver.sm_60.asm
|
||||
cuobjdump -xelf sm_61 ./driver/driver && nvdisasm --print-code -g driver.sm_61.cubin > driver.sm_61.asm
|
||||
cuobjdump -xelf sm_70 ./driver/driver && nvdisasm --print-code -g driver.sm_70.cubin > driver.sm_70.asm
|
||||
DRIVER=$1
|
||||
ARCH=$2
|
||||
cuobjdump -xelf $ARCH ./driver/$DRIVER && nvdisasm --print-code -g $DRIVER.$ARCH.cubin > $DRIVER.$ARCH.asm
|
||||
|
||||
Reference in New Issue
Block a user