Bwd Data NHWC (#22)
* fix buffer_store bug
* remove obsolete kernels
* add bwd-data-v5r1-nhwc
@@ -1,268 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = C * YTilda * XTilda;
// GemmN = N * HTildaSlice * WTildaSlice;
// GemmK = K * YDot * XDot;
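// e.g., a sketch assuming a 3x3 filter with stride 2 and dilation 1 in both dims:
// YTilda = XTilda = 2 and YDot = XDot = 2, so GemmM = 4 * C and GemmK = 4 * K, and the
// whole backward-data problem maps onto a single implicit GEMM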
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0 // debug
        // sanity-check for vectorized memory load
        // TODO: this logic may not be correct for bwd-data
        static_assert(
            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
            "wrong! alignment requirement for vectorized global load of input tensor will "
            "be violated");
#endif

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        constexpr index_t HTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t WTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t HTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t WTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
        constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
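        // worked example (illustrative only): Hi = 8, Y = 3, left/right pad = 1,
        // stride = 2, dilation = 1 gives Ho = 4, YTilda = 2, YDot = 2,
        // HTilda = 4 + 1 = 5, HTildaLeft = 0, HTildaRight = min(5, 4 + 1) = 5,
        // hence HTildaSlice = 5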
        // weight tensor
        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
            make_tuple(Merge<Sequence<K, YDot, XDot>>{}, Merge<Sequence<C, YTilda, XTilda>>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
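        // intuition: the negative Embed coefficient (-ConvDilationH / GcdStrideDilationH)
        // steps backward through Ho as the YDot index grows, so each input pixel gathers
        // exactly the output pixels that consumed it in the forward pass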
        constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<HTildaLeft, WTildaLeft>,
                                 Sequence<HTildaRight, WTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto out_gemmk_gemmn_global_desc =
            transform_tensor_descriptor(out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                                        make_tuple(Merge<Sequence<K, YDot, XDot>>{},
                                                   Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
                                        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

#if 1 // debug
        constexpr bool in_skip_all_out_of_bound_check = false;
#else
        constexpr bool in_skip_all_out_of_bound_check = true;
#endif

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_all_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<HTildaLeft, WTildaLeft>,
                                 Sequence<HTildaRight, WTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto in_gemmm_gemmn_global_desc =
            transform_tensor_descriptor(in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
                                        make_tuple(Merge<Sequence<C, YTilda, XTilda>>{},
                                                   Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
                                        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_gemmk_gemmm_global_desc),
                                                     decltype(out_gemmk_gemmn_global_desc),
                                                     decltype(in_gemmm_gemmn_global_desc),
                                                     InMemoryDataOperation::Set,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThread,
                                                     GemmNPerThread,
                                                     GemmKPerThread,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     GemmThreadGemmDataPerReadM,
                                                     GemmThreadGemmDataPerReadN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmABlockCopySrcDataPerRead_GemmM,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmBBlockCopySrcDataPerRead_GemmN,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<0, 1, 2, 3>,
                                                     3,
                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
};

} // namespace ck
#endif
@@ -1,388 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
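// e.g., a sketch assuming stride 2 and dilation 1 in both dims: YTilda * XTilda = 4
// GEMMs are run back-to-back inside one kernel, one per (iYTilda, iXTilda) phase of
// the input tensor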
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
{
    // this is a hack, should query this info from gridwise_gemm instead of duplicating its logic
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
                                                    GemmBBlockCopyDstDataPerWrite_GemmN,
                                                    GemmThreadGemmDataPerReadM,
                                                    GemmThreadGemmDataPerReadN);

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space + b_block_space) * sizeof(Float);
    }

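    // host-side usage sketch (hypothetical caller, not part of this header): the LDS
    // requirement can be checked before launching, e.g.
    //   constexpr index_t lds_size = GetSharedMemoryNumberOfByte();
    //   static_assert(lds_size <= 64 * 1024, "exceeds LDS capacity");
    // the factor of 2 in the returned size presumably covers the double-buffered A/B
    // blocks used by the GEMM pipeline
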
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0 // debug
        // sanity-check for vectorized memory load
        // TODO: this logic may not be correct for bwd-data
        static_assert(
            (Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
                (X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
            "wrong! alignment requirement for vectorized global load of input tensor will "
            "be violated");
#endif

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        constexpr index_t HTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t WTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t HTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t WTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
        constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

        constexpr bool wei_skip_all_out_of_bound_check = true;

        // weight tensor
        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_all_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

#if 1 // debug
        constexpr bool out_skip_all_out_of_bound_check = false;
#else
        constexpr bool out_skip_all_out_of_bound_check = true;
#endif

        // output tensor
        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_all_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<HTildaLeft, WTildaLeft>,
                                 Sequence<HTildaRight, WTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

#if 1 // debug
        constexpr bool in_skip_all_out_of_bound_check = false;
#else
        constexpr bool in_skip_all_out_of_bound_check = true;
#endif

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_all_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_all_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<HTildaLeft, WTildaLeft>,
                                 Sequence<HTildaRight, WTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        // GEMMs
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        static_for<0, YTilda, 1>{}([&](auto iYTilda_) {
            static_for<0, XTilda, 1>{}([&](auto iXTilda_) {
                constexpr index_t iYTilda = decltype(iYTilda_){};
                constexpr index_t iXTilda = decltype(iXTilda_){};

                constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
                constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;
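                // e.g. with Y = 3 and YDot = 2, the two iYTilda phases see
                // YDotSlice = 2 and YDotSlice = 1, respectively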

                // A matrix
                constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
                    transform_tensor_descriptor(
                        wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
                        make_tuple(PassThrough<K>{},
                                   PassThrough<C>{},
                                   Slice<Sequence<YDot, XDot>,
                                         Sequence<0, 0>,
                                         Sequence<YDotSlice, XDotSlice>>{},
                                   Slice<Sequence<YTilda, XTilda>,
                                         Sequence<iYTilda, iXTilda>,
                                         Sequence<iYTilda + 1, iXTilda + 1>>{}),
                        make_tuple(
                            Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
                        make_tuple(
                            Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

                constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
                    wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
                    make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                               Merge<Sequence<C, 1, 1>>{}),
                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                // B matrix
                constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
                    transform_tensor_descriptor(
                        out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                        make_tuple(PassThrough<N>{},
                                   PassThrough<K>{},
                                   PassThrough<HTildaSlice>{},
                                   PassThrough<WTildaSlice>{},
                                   Slice<Sequence<YDot, XDot>,
                                         Sequence<0, 0>,
                                         Sequence<YDotSlice, XDotSlice>>{}),
                        make_tuple(Sequence<0>{},
                                   Sequence<1>{},
                                   Sequence<3>{},
                                   Sequence<5>{},
                                   Sequence<2, 4>{}),
                        make_tuple(Sequence<0>{},
                                   Sequence<1>{},
                                   Sequence<3>{},
                                   Sequence<5>{},
                                   Sequence<2, 4>{}));

                constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
                    out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
                    make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                               Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                // C matrix
                constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
                    transform_tensor_descriptor(
                        in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
                        make_tuple(PassThrough<N>{},
                                   PassThrough<C>{},
                                   PassThrough<HTildaSlice>{},
                                   PassThrough<WTildaSlice>{},
                                   Slice<Sequence<YTilda, XTilda>,
                                         Sequence<iYTilda, iXTilda>,
                                         Sequence<iYTilda + 1, iXTilda + 1>>{}),
                        make_tuple(Sequence<0>{},
                                   Sequence<1>{},
                                   Sequence<3>{},
                                   Sequence<5>{},
                                   Sequence<2, 4>{}),
                        make_tuple(Sequence<0>{},
                                   Sequence<1>{},
                                   Sequence<3>{},
                                   Sequence<5>{},
                                   Sequence<2, 4>{}));

                constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
                    in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
                    make_tuple(Merge<Sequence<C, 1, 1>>{},
                               Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
                    GridSize,
                    BlockSize,
                    Float,
                    AccFloat,
                    decltype(wei_gemmk_gemmm_global_desc),
                    decltype(out_gemmk_gemmn_global_desc),
                    decltype(in_gemmm_gemmn_global_desc),
                    InMemoryDataOperation::Set,
                    GemmMPerBlock,
                    GemmNPerBlock,
                    GemmKPerBlock,
                    GemmMPerThread,
                    GemmNPerThread,
                    GemmKPerThread,
                    GemmMLevel0Cluster,
                    GemmNLevel0Cluster,
                    GemmMLevel1Cluster,
                    GemmNLevel1Cluster,
                    GemmThreadGemmDataPerReadM,
                    GemmThreadGemmDataPerReadN,
                    GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                    GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                    Sequence<0, 1>,
                    Sequence<0, 1>,
                    1,
                    GemmABlockCopySrcDataPerRead_GemmM,
                    GemmABlockCopyDstDataPerWrite_GemmM,
                    GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                    GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                    Sequence<0, 1>,
                    Sequence<0, 1>,
                    1,
                    GemmBBlockCopySrcDataPerRead_GemmN,
                    GemmBBlockCopyDstDataPerWrite_GemmN,
                    Sequence<0, 1, 2, 3>,
                    3,
                    GemmCThreadCopyDstDataPerWrite_GemmN1>{};

                gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);

                // is synchronization necessary?
                __syncthreads();
            });
        });
    }
};

} // namespace ck
#endif
@@ -167,9 +167,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        //\todo static_assert for global vector load/store
        // static_assert();

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

@@ -179,6 +176,9 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
@@ -198,10 +198,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        // weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        // weight tensor
        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
@@ -217,15 +217,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor(
            wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
            make_tuple(
                PassThrough<K>{},
                PassThrough<C>{},
                Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydotslice_xdotslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, PassThrough<C>{}),
            make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B matrix: output tensor
        // TODO sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        //\todo sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        // output tensor
        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
@@ -246,8 +262,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
            out_n_k_ydot_htilda_xdot_wtilda_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       PassThrough<YTilda>{},
                       PassThrough<XTilda>{},
                       PassThrough<YDot>{},
                       PassThrough<XDot>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{}),
@@ -256,14 +272,35 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    PassThrough<K>{},
                    PassThrough<HTildaSlice>{},
                    PassThrough<WTildaSlice>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // C matrix: input tensor
        // TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        //\todo sometimes input out-of-bound check can be skipped, find out all such situations
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
@@ -291,87 +328,21 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<YTilda>{},
                           PassThrough<XTilda>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<iHTildaLeft, iWTildaLeft>,
                                 Sequence<iHTildaRight, iWTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        // GEMM
        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        // A matrix
        constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc =
            transform_tensor_descriptor(
                wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
                make_tuple(
                    PassThrough<K>{},
                    PassThrough<C>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                    Slice<Sequence<YTilda, XTilda>,
                          Sequence<iYTilda, iXTilda>,
                          Sequence<iYTilda + 1, iXTilda + 1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, Merge<Sequence<C, 1, 1>>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B matrix
        constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    PassThrough<K>{},
                    PassThrough<HTildaSlice>{},
                    PassThrough<WTildaSlice>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // C matrix
        constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           PassThrough<HTildaSlice>{},
                           PassThrough<WTildaSlice>{},
                           Slice<Sequence<YTilda, XTilda>,
                                 Sequence<iYTilda, iXTilda>,
                                 Sequence<iYTilda + 1, iXTilda + 1>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
        constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor(
            in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            in_n_c_htildaslice_wtildaslice_global_desc,
            make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto gridwise_gemm =

@@ -0,0 +1,406 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK0 = YDotSlice
// GemmK1 = XDotSlice
// GemmK2 = K
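// e.g., a sketch assuming a 3x3 filter with stride 2 and dilation 1: the gemm_id = 0
// GEMM sees GemmK0 = GemmK1 = 2 and GemmK2 = K; keeping K as the innermost K dimension
// matches the NHWC/NHWK layouts, where K is contiguous in memory, so the B-matrix copy
// can read vectors along GemmK2 (see GemmBBlockCopySrcDataPerRead_GemmK2 below)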
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmK2,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
{
    __host__ __device__ static constexpr index_t GetNumberOfGemm()
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        return YTilda * XTilda;
    }

    __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
    {
        constexpr index_t N  = InGlobalDesc::GetLengths()[0];
        constexpr index_t Hi = InGlobalDesc::GetLengths()[1];
        constexpr index_t Wi = InGlobalDesc::GetLengths()[2];
        constexpr index_t C  = InGlobalDesc::GetLengths()[3];

        constexpr index_t Ho = OutGlobalDesc::GetLengths()[1];
        constexpr index_t Wo = OutGlobalDesc::GetLengths()[2];
        constexpr index_t K  = OutGlobalDesc::GetLengths()[3];

        constexpr index_t Y = WeiGlobalDesc::GetLengths()[1];
        constexpr index_t X = WeiGlobalDesc::GetLengths()[2];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // GemmM and GemmN
        constexpr index_t GemmM = C;
        constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

        // GemmK is different for each GEMM
        index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        index_t GemmK0 = YDotSlice;
        index_t GemmK1 = XDotSlice;
        index_t GemmK2 = K;

        return Array<index_t, 5>{GemmM, GemmN, GemmK0, GemmK1, GemmK2};
    }

    __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
    {
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        index_t iYTilda = gemm_id / XTilda;
        index_t iXTilda = gemm_id % XTilda;

        return GetGemmSizeImpl(iYTilda, iXTilda);
    }
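    // host-side dispatch sketch (hypothetical caller, not part of this header):
    //   for(index_t gemm_id = 0; gemm_id < GetNumberOfGemm(); ++gemm_id)
    //   {
    //       const auto gemm_size = GetGemmSize(gemm_id); // {GemmM, GemmN, K0, K1, K2}
    //       // skip degenerate GEMMs (zero-length K slice) and launch the rest
    //   }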

    template <index_t iYTilda, index_t iXTilda>
    __device__ static void RunImpl(Float* __restrict__ p_in_global,
                                   const Float* __restrict__ p_wei_global,
                                   const Float* __restrict__ p_out_global)
    {
        constexpr auto in_n_hi_wi_c_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_y_x_c_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_hi_wi_c_global_desc.GetLengths()[0];
        constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1];
        constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2];
        constexpr index_t C  = in_n_hi_wi_c_global_desc.GetLengths()[3];

        constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1];
        constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2];
        constexpr index_t K  = out_n_ho_wo_k_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1];
        constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot;
        constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot;

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        // weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor(
            wei_k_y_x_c_global_desc,
            make_tuple(PassThrough<K>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor(
            wei_k_ydot_ytilda_xdot_xtilda_c_global_desc,
            make_tuple(
                PassThrough<K>{},
                Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{}));

        constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc =
            reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc,
                                                        Sequence<2, 0, 1, 3>{});

        // B matrix: output tensor
        // TODO sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor(
            out_n_ho_wo_k_global_desc,
            make_tuple(PassThrough<N>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       PassThrough<K>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc =
            transform_tensor_descriptor(
                out_n_ydot_htilda_xdot_wtilda_k_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                    Slice<Sequence<HTilda, WTilda>,
                          Sequence<iHTildaLeft, iWTildaLeft>,
                          Sequence<iHTildaRight, iWTildaRight>>{},
                    PassThrough<K>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}));

        constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor(
            out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc,
            make_tuple(PassThrough<YDotSlice>{},
                       PassThrough<XDotSlice>{},
                       PassThrough<K>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // C matrix: input tensor
        // TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1];
        constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2];

        constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor(
            in_n_ytilda_htilda_xtilda_wtilda_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_htildaslice_wtildaslice_c_global_desc,
            make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // call GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2<
            GridSize,
            BlockSize,
            Float,
            AccFloat,
            decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc),
            decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            InMemoryDataOperation::Set,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            ThreadGemmDataPerRead_GemmM,
            ThreadGemmDataPerRead_GemmN,
            GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            3,
            GemmABlockCopySrcDataPerRead_GemmM,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
            Sequence<0, 1, 3, 2>,
            Sequence<0, 1, 3, 2>,
            2,
            GemmBBlockCopySrcDataPerRead_GemmK2,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<2, 3, 0, 1>,
            3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }

    template <index_t GemmId>
    __device__ static void Run(Float* __restrict__ p_in_global,
                               const Float* __restrict__ p_wei_global,
                               const Float* __restrict__ p_out_global,
                               Number<GemmId>)
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t iYTilda = GemmId / XTilda;
        constexpr index_t iXTilda = GemmId % XTilda;

        static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYTilda, iXTilda");

        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
    }
};

} // namespace ck
#endif
||||

@@ -488,6 +488,49 @@ struct Embed
    }
};

// LowerLengths: Sequence<...>
// LowerFreezePoint: Sequence<...>
template <typename LowerLengths, typename LowerFreezePoint>
struct Freeze
{
    static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp  = 0;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ explicit constexpr Freeze()
    {
        // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/)
    {
        return to_array(LowerFreezePoint{});
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return make_zero_array<index_t, nDimLow>();
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};
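
// Example (hypothetical instantiation, not in this commit):
// Freeze<Sequence<4, 4>, Sequence<1, 2>> exposes a 0-d view of a 4x4 tensor:
// GetUpperLengths() is the empty Sequence<>, CalculateLowerIndex({}) always
// returns the frozen point (1, 2), and all lower-index diffs are zero, so
// every access through the view lands on that single lower coordinate.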

template <index_t LowerLength, index_t VectorSize>
struct Vectorize
{

@@ -376,5 +376,400 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
    }
};

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t ThreadGemmAThreadCopySrcDataPerRead_M,
          index_t ThreadGemmBThreadCopySrcDataPerRead_N,
          typename ABlockCopyThreadSliceLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterLengths_K0_K1_K2_M,
          typename ABlockCopyThreadClusterArrangeOrder,
          typename ABlockCopySrcAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          typename BBlockCopyThreadSliceLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterLengths_K0_K1_K2_N,
          typename BBlockCopyThreadClusterArrangeOrder,
          typename BBlockCopySrcAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          typename CThreadCopySrcDstAccessOrder,
          index_t CThreadCopySrcDstVectorReadWriteDim,
          index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v2
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space + b_block_space) * sizeof(Float);
    }
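
    // for intuition (illustrative numbers, not fixed by this header): with
    // KPerBlock = 8 and MPerBlock = NPerBlock = 128, and ignoring alignment
    // padding, a_block_space = b_block_space = 8 * 128 = 1024 elements, so
    // the double-buffered LDS budget is 2 * (1024 + 1024) * sizeof(Float),
    // i.e. 16 KiB for a 4-byte Float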

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global,
                        Float* __restrict__ p_shared_block) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto I0 = Number<0>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
        constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc        = CGlobalDesc{};

        constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0];
        constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1];
        constexpr auto K  = a_k0_k1_k2_m_global_desc.GetLengths()[2];
        constexpr auto M  = c_m_n_global_desc.GetLengths()[0];
        constexpr auto N  = c_m_n_global_desc.GetLengths()[1];

        // don't do anything if K == 0
        if(K == 0)
        {
            return;
        }

        // lds max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k0_k1_k2_m_global_desc),
                                               decltype(a_k0_k1_k2_m_block_desc),
                                               decltype(a_k0_k1_k2_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterLengths_K0_K1_K2_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               ABlockCopySrcVectorReadDim,
                                               3,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0});

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, 1, KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k0_k1_k2_n_global_desc),
                                               decltype(b_k0_k1_k2_n_block_desc),
                                               decltype(b_k0_k1_k2_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterLengths_K0_K1_K2_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               Sequence<0, 1, 2, 3>,
                                               BBlockCopySrcVectorReadDim,
                                               3,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, MPerBlock] is in LDS
        // b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
            unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
            unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
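
        // e.g. with MPerBlock = 128, MPerThread = 4, MLevel0Cluster = 4 and
        // MLevel1Cluster = 8 (illustrative values), GemmMRepeat = 128 / 128 = 1;
        // halving MLevel1Cluster would give GemmMRepeat = 2, i.e. each thread
        // would own two disjoint M-strips of the block tile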

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        for(index_t k0 = 0; k0 < K0; ++k0)
        {
            for(index_t k1 = 0; k1 < K1; ++k1)
            {
                // LDS double buffer: preload data into LDS
                {
                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
                }

                constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
                constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{};
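
                // each step advances the source window by KPerBlock along the
                // K2 axis; the K1 and K0 coordinates are stepped separately at
                // the bottom of their respective loops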

                // LDS double buffer: main body
                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
                    k_block_data_begin += 2 * KPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_a_block_now =
                            even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                        Float* p_b_block_now =
                            even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                        Float* p_a_block_next =
                            even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                        Float* p_b_block_next =
                            even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                        // LDS double buffer: store next data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

                    if(has_two_iteration_left) // if 2 iterations are left
                    {
                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

                        __syncthreads();

                        // LDS double buffer: load last data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on 2nd-last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                        // LDS double buffer: store last data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                              p_a_block_double + a_block_space);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                              p_b_block_double + b_block_space);

                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double + a_block_space,
                                           p_b_block_double + b_block_space,
                                           p_c_thread);
                    }
                    else // if 1 iteration is left
                    {
                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
                    }
                }

                // reset slice window on K2 dimension, then move forward on K1 dimension
                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
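                // (after the K-loop the window sits at K2 = K - KPerBlock; the
                // False flag steps it backward, rewinding K2 to 0 before K1 advances)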

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
            }

            // reset slice window on K1 dimension, then move forward on K0 dimension
            a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);

            a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
        }

        // input: register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define input tensor descriptor for threadwise copy
            // thread input tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                {0, 0, 0, 0},
                {m_thread_data_on_global / M1,
                 m_thread_data_on_global % M1,
                 n_thread_data_on_global / N1,
                 n_thread_data_on_global % N1})
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};

} // namespace ck
#endif
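
Reduced to its skeleton, the ping-pong scheme in the main body above looks like the following host-side sketch; load/store/gemm are stand-ins for the blockwise copy and GEMM calls, the tile sizes are assumed, and the kernel additionally unrolls the swap by two:

// skeleton of LDS double buffering: compute on buffer `now` while the next
// K-slab is staged into registers and then published to buffer `1 - now`
#include <cstdio>

int main()
{
    constexpr int K = 64, KPerBlock = 8; // assumed sizes

    auto load  = [](int slab) { std::printf("load slab %d to registers\n", slab); };
    auto store = [](int buf) { std::printf("store registers to LDS buffer %d\n", buf); };
    auto gemm  = [](int buf) { std::printf("gemm on LDS buffer %d\n", buf); };

    int now = 0;
    std::printf("preload slab 0 to LDS buffer 0\n");

    int slab = 0;
    for(; slab + 2 * KPerBlock < K; slab += KPerBlock)
    {
        load(slab + KPerBlock); // prefetch the next slab while...
        gemm(now);              // ...computing on the current one
        store(1 - now);         // publish the prefetched slab
        now = 1 - now;          // ping-pong
    }

    // tail: one or two slabs remain, handled as in the kernel's tail block
    gemm(now);
}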

@@ -12,6 +12,7 @@ struct Array
    using type = Array<TData, NSize>;
    using data_type = TData;

    // TODO: implement empty Array
    index_t mData[NSize];

    __host__ __device__ explicit constexpr Array() {}

@@ -24,6 +24,7 @@

#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
#endif

#endif

@@ -108,8 +108,12 @@ struct SetData
    {
        const auto zeros = vector_t(0);

-        amd_buffer_store<T, DataPerAccess>(
-            src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range);
+        amd_buffer_store<T, DataPerAccess>(src_valid ? &(p_src[src_offset])
+                                                     : reinterpret_cast<const T*>(&zeros),
+                                           p_dst,
+                                           dst_offset,
+                                           dst_valid,
+                                           dst_range);
    }
#endif
};
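
The buffer_store fix above is needed because zeros is a vector_t, so &zeros cannot stand in for a const T* source pointer without a cast. A plain-C++ sketch of the predicated-store pattern, with simplified stand-ins (assumptions) for vector_t and the store wrapper:

#include <array>
#include <cstdio>

// simplified stand-ins: `vector_t` models the 4-wide vector type and
// `buffer_store` models the amd_buffer_store wrapper
using T        = float;
using vector_t = std::array<T, 4>;

void buffer_store(const T* p_src, T* p_dst, int dst_offset, bool dst_valid)
{
    if(dst_valid)
        for(int i = 0; i < 4; ++i)
            p_dst[dst_offset + i] = p_src[i];
}

int main()
{
    T src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    T dst[8] = {};

    const vector_t zeros{}; // all lanes zero

    const bool src_valid = false; // pretend the source read is out of range

    // the cast is the essence of the fix: &zeros is a vector_t*, while the
    // store expects a const T*, so it must be reinterpreted before use
    buffer_store(src_valid ? &src[0] : reinterpret_cast<const T*>(&zeros), dst, 0, true);

    std::printf("dst[0] = %g\n", dst[0]); // 0: the zero-fill path was taken
}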

@@ -145,19 +149,17 @@ struct AtomicAddData
    template <>
    __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                  index_t src_offset,
                                                                  bool src_valid,
                                                                  index_t /* src_range */,
-                                                                 bool src_valid T* p_dst,
+                                                                 T* p_dst,
                                                                  index_t dst_offset,
                                                                  bool dst_valid,
                                                                  index_t dst_range) const
    {
        const auto zeros = vector_t(0);

-        amd_buffer_atomic_add<T, DataPerAccess>(src_valid ? &(p_src[src_offset]) : &zeros,
-                                                p_dst,
-                                                dst_offset,
-                                                dst_valid,
-                                                index_t dst_range);
+        amd_buffer_atomic_add<T, DataPerAccess>(
+            src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range);
    }
#endif
};