Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 01:36:06 +00:00)
Added bwd data v3r1 v4r1, tweaking v1 (#10)
* Added bwd data v3r1: breaks the computation down into a series of load-balanced GEMMs and launches them in a single kernel
* Added bwd data v4r1: like v3r1, but launches the GEMMs in multiple kernels
* Tweaked v1r1 and v1r2 (atomic) on AMD GPU
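For readers of the diff below: v3r1/v4r1 split backward-data into Ytilda * Xtilda smaller GEMMs. A rough host-side sketch (mine, not from this commit; helper names are illustrative) of how those counts fall out of stride, dilation and filter size:

    #include <cstdio>
    #include <numeric> // std::gcd

    int main()
    {
        // Illustrative config: 3x3 filter, stride 2, dilation 1.
        const int Y = 3, X = 3, StrideH = 2, StrideW = 2, DilationH = 1, DilationW = 1;

        const int Ytilda = StrideH / std::gcd(StrideH, DilationH); // filter-phase groups along H
        const int Xtilda = StrideW / std::gcd(StrideW, DilationW); // filter-phase groups along W
        const int Ydot   = (Y + Ytilda - 1) / Ytilda;              // taps per group along H
        const int Xdot   = (X + Xtilda - 1) / Xtilda;              // taps per group along W

        // v3r1 iterates over all (ytilda, xtilda) pairs inside one kernel; v4r1 launches one
        // kernel per pair. Each pair is its own GEMM with GemmM = C,
        // GemmN = N * HtildaTrim * WtildaTrim, GemmK = K * YdotNonZero * XdotNonZero.
        std::printf("number of GEMMs = %d, up to %d x %d filter taps each\n",
                    Ytilda * Xtilda, Ydot, Xdot);
    }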
@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

@@ -122,7 +126,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

// GEMM: atomic add
// GEMM
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;

constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
@@ -131,7 +139,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add,
in_memory_op,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
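The condition that picks plain stores over atomics above is worth spelling out; the reading below is my own, and the helper name is hypothetical:

    // When Y <= ConvStrideH and X <= ConvStrideW, the filter footprints of neighbouring
    // output pixels do not overlap on the input, so every input element receives at most
    // one GEMM contribution and InMemoryDataOperation::none (a plain store) is safe.
    // Otherwise several output pixels write the same input element and atomic_add is needed.
    constexpr bool bwd_data_needs_atomic_add(int y, int x, int stride_h, int stride_w)
    {
        return !(y <= stride_h && x <= stride_w);
    }
    static_assert(!bwd_data_needs_atomic_add(1, 1, 2, 2), "1x1 filter, stride 2: no overlap");
    static_assert(bwd_data_needs_atomic_add(3, 3, 1, 1), "3x3 filter, stride 1: overlapping writes");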
@@ -352,8 +352,16 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
}
}

// input: register to global memory, atomic add
{
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
: InMemoryDataOperation::atomic_add;
#else
constexpr auto in_memory_op = InMemoryDataOperation::atomic_add;
#endif

constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t E0 = E / E1;
@@ -378,8 +386,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + LeftPads::At(0) + RightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
@@ -422,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
InThreadCopyDstDataPerWrite_B,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
|
||||
{e_thread_data_on_global / E1,
|
||||
e_thread_data_on_global % E1,
|
||||
0,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1,
|
||||
0})
|
||||
in_memory_op>({0, 0, 0, 0, 0, 0},
|
||||
{e_thread_data_on_global / E1,
|
||||
e_thread_data_on_global % E1,
|
||||
0,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1,
|
||||
0})
|
||||
.Run(p_in_thread, p_in_global);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
namespace ck {

// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * Ydot * Xdot;
template <index_t GridSize,
index_t BlockSize,
@@ -20,8 +20,8 @@ template <index_t GridSize,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InputLeftPads,
typename InputRightPads,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
@@ -71,6 +71,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0 // debug
// sanity-check for vectorized memory load
// TODO: this logic may not be correct for bwd-data
static_assert(
@@ -78,6 +79,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
#endif

constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
@@ -88,30 +90,34 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
constexpr index_t Wtilda =
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
constexpr index_t HtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
constexpr index_t WtildaLeft = math::integer_divide_floor(
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);

constexpr index_t HtildaRight = math::min(
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
constexpr index_t WtildaRight = math::min(
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
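A small worked example (mine, not from the commit) may help when reading the Htilda / HtildaLeft / HtildaRight arithmetic above:

    // Example: Hi = 28, Ho = 14, Y = 3, ConvStrideH = 2, ConvDilationH = 1, InLeftPads[0] = 1
    //   hcf_stride_dilation_h = hcf(2, 1) = 1
    //   Ytilda = 2 / 1 = 2,  Ydot = ceil(3 / 2) = 2
    //   Htilda = Ho + ceil(ConvDilationH * (Y - 1) / ConvStrideH) = 14 + ceil(2 / 2) = 15
    //   HtildaLeft  = floor(max(0, 1 - 1 * (2 - 1)) / 2) = 0
    //   HtildaRight = min(15, ceil((1 + 28 - 1) / 2) + 1) = min(15, 15) = 15
    //   HtildaTrim  = 15 - 0 = 15
    // i.e. the padded "strided output" grid is trimmed to the rows that can actually reach
    // a valid (non-padding) input row.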
// weight tensor
|
||||
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Y, X>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_yp_xp_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ydot, Ytilda>,
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Xtilda>,
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -123,55 +129,96 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_hop_wop_global_desc,
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Sequence<Ydot, Htilda>,
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Wtilda>,
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr auto gridwise_gemm =
|
||||
|
||||
@@ -0,0 +1,393 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V3R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Ytilda*Xtilda number of GEMMs
// GemmM = C;
// GemmN = N * HtildaNonZero * WtildaNonZero;
// GemmK = K * YdotNonZero * XdotNonZero;
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
{
|
||||
// this is a hack, should query this info from gridwise_gemm instead of duplicate its logic
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr index_t max_lds_align = math::lcm(GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gemmk_gemmm_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<GemmKPerBlock, GemmMPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gemmk_gemmn_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<GemmKPerBlock, GemmNPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space =
|
||||
math::integer_least_multiple(a_gemmk_gemmm_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_gemmk_gemmn_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space + b_block_space) * sizeof(Float);
|
||||
}
|
||||
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0 // debug
|
||||
// sanity-check for vectorized memory load
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr bool wei_skip_all_out_of_bound_check = true;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool out_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool out_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMMs
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

__shared__ Float p_shared_block[shared_block_size];

#if 1 // debug
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
#else
static_for<0, 1, 1>{}([&](auto ytilda_) {
static_for<0, 1, 1>{}([&](auto xtilda_) {
#endif
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};

constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
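One illustrative case (editorial) for the YdotNonZero expression above:

    // Example: Y = 3, Ytilda = 2, so Ydot = ceil(3 / 2) = 2.
    //   ytilda = 0: (0 + 1) * 2 = 2 <= 3  ->  YdotNonZero = Ydot = 2      (a full group of taps)
    //   ytilda = 1: (1 + 1) * 2 = 4 >  3  ->  YdotNonZero = Y % Ydot = 1  (the padded tail group)
    // Trimming each GEMM's K extent to the filter taps that actually exist is what keeps the
    // Ytilda * Xtilda GEMMs load-balanced instead of multiplying by zero-padded filter entries.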
// A matrix
|
||||
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<C, 1, 1>>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B matrix
|
||||
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C matrix
|
||||
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<5>{},
|
||||
Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, 1, 1>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v1<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};

gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global, p_shared_block);

// is synchronization necessary?
__syncthreads();
});
});
}
};

} // namespace ck
#endif
@@ -0,0 +1,333 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = C
|
||||
// GemmN = N * Htilda * Wtilda;
|
||||
// GemmK = K * YdotNonZero * XdotNonZero
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t Iter_ytilda,
|
||||
index_t Iter_xtilda,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0 // debug
|
||||
// sanity-check for vectorized memory load
|
||||
// TODO: this logic may not be correct for bwd-data
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDstDataPerWrite_GemmN1 == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmCThreadCopyDstDataPerWrite_GemmN1 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr bool wei_skip_all_out_of_bound_check = true;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool out_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool out_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
#if 1 // debug
|
||||
constexpr bool in_skip_all_out_of_bound_check = false;
|
||||
#else
|
||||
constexpr bool in_skip_all_out_of_bound_check = true;
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Hip,
|
||||
Sequence<Ytilda, Htilda>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>,
|
||||
in_skip_all_out_of_bound_check>{},
|
||||
Embed<Wip,
|
||||
Sequence<Xtilda, Wtilda>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>,
|
||||
in_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<Ytilda>{},
|
||||
PassThrough<Xtilda>{},
|
||||
Slice<Sequence<Htilda, Wtilda>,
|
||||
Sequence<HtildaLeft, WtildaLeft>,
|
||||
Sequence<HtildaRight, WtildaRight>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr index_t ytilda = Iter_ytilda;
|
||||
constexpr index_t xtilda = Iter_xtilda;
|
||||
|
||||
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
|
||||
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
|
||||
|
||||
// A matrix
|
||||
constexpr auto wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_YdotNonZero_1_XdotNonZero_1_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{}, Merge<Sequence<C, 1, 1>>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B matrix
|
||||
constexpr auto out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_k_ydot_htildatrim_xdot_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ydot, Xdot>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<YdotNonZero, XdotNonZero>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_YdotNonZero_htildatrim_XdotNonZero_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<K, YdotNonZero, XdotNonZero>>{},
|
||||
Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C matrix
|
||||
constexpr auto in_n_c_1_htildatrim_1_wtildatrim_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htildatrim_xtilda_wtildatrim_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
PassThrough<HtildaTrim>{},
|
||||
PassThrough<WtildaTrim>{},
|
||||
Slice<Sequence<Ytilda, Xtilda>,
|
||||
Sequence<ytilda, xtilda>,
|
||||
Sequence<ytilda + 1, xtilda + 1>>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_1_htildatrim_1_wtildatrim_global_desc,
|
||||
make_tuple(Merge<Sequence<C, 1, 1>>{}, Merge<Sequence<N, HtildaTrim, WtildaTrim>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};

gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};

} // namespace ck
#endif
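v4r1 moves the loop over (ytilda, xtilda) from device code (v3r1's static_for) to separate kernel launches driven by the Iter_ytilda / Iter_xtilda template parameters. A hypothetical host-side driver, with stand-in names that are not part of this commit, could look like:

    #include <cstdio>

    template <int IterYtilda, int IterXtilda>
    void launch_bwd_data_v4r1_gemm() // stand-in for the real kernel launch
    {
        std::printf("launch GEMM for (ytilda, xtilda) = (%d, %d)\n", IterYtilda, IterXtilda);
    }

    template <int Ytilda, int Xtilda, int IY = 0, int IX = 0>
    void launch_all_bwd_data_gemms()
    {
        launch_bwd_data_v4r1_gemm<IY, IX>();
        if constexpr (IX + 1 < Xtilda)
            launch_all_bwd_data_gemms<Ytilda, Xtilda, IY, IX + 1>();
        else if constexpr (IY + 1 < Ytilda)
            launch_all_bwd_data_gemms<Ytilda, Xtilda, IY + 1, 0>();
    }

    int main() { launch_all_bwd_data_gemms<2, 2>(); } // e.g. 3x3 filter, stride 2: 4 GEMMs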
@@ -181,12 +181,15 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
|
||||
@@ -97,12 +97,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
|
||||
@@ -41,15 +41,17 @@ struct PassThrough

__host__ __device__ static constexpr bool IsLinearTransform() { return true; }

__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
};

// LowerLengths: Sequence<...>
template <typename LowerLengths, typename LeftPads, typename RightPads>
template <typename LowerLengths,
typename LeftPads,
typename RightPads,
bool SkipIsValidCheck = false>
struct Pad
{
static constexpr index_t nDim = LowerLengths::Size();
@@ -57,6 +59,13 @@ struct Pad
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;

__host__ __device__ explicit constexpr Pad()
{
static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
RightPads::GetSize() == nDim,
"wrong! # of dimensions not consistent");
}

__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
@@ -81,20 +90,83 @@ struct Pad

__host__ __device__ static constexpr bool IsLinearTransform() { return true; }

__host__ __device__ constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;

static_for<0, nDim, 1>{}([&](auto idim) {
flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
(idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
});
for(index_t i = 0; i < nDim; ++i)
{
flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
}

return flag;
}
};
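The interface change above is easy to misread; my understanding of it, as an editorial comment:

    // Old interface: IsUpperIndexMappedToValidLowerIndex(idx_up) answered the question per
    // upper index, at run time, by checking idx_up against the pad boundaries.
    // New interface: IsValidUpperIndexAlwaysMappedToValidLowerIndex() answers it once, at
    // compile time: for Pad it can only be true when every left/right pad is zero (nothing
    // ever lands in the padding region), or when SkipIsValidCheck explicitly waives the check.
    // Per-index validity now lives in the coordinate classes (IsUpperIndexValid / IsOffsetValid
    // further down in this commit).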
|
||||
// LowerLengths: Sequence<...>
|
||||
// SliceBegins: Sequence<...>
|
||||
// SliceEnds: Sequence<...>
|
||||
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
|
||||
struct Slice
|
||||
{
|
||||
static constexpr index_t nDim = LowerLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ explicit constexpr Slice()
|
||||
{
|
||||
static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
|
||||
SliceEnds::GetSize() == nDim,
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
#if 0
|
||||
// TODO: would not compile, error on constexpr
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
|
||||
SliceBegins::At(idim) >= 0 &&
|
||||
SliceEnds::At(idim) <= LowerLengths::At(idim),
|
||||
"wrong! Slice config is wrong");
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return SliceEnds{} - SliceBegins{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
return idx_up + SliceBegins{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const UpperIndex& /* idx_up_old */,
|
||||
const LowerIndex& /* idx_low_old */)
|
||||
{
|
||||
return idx_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths>
|
||||
struct Merge
|
||||
@@ -165,85 +237,101 @@ struct Merge
|
||||
const UpperIndex& /* idx_up_old */,
|
||||
const LowerIndex& idx_low_old)
|
||||
{
|
||||
// do nothing if idx_up_diff == 0
|
||||
if(idx_up_diff[0] == 0)
|
||||
{
|
||||
return make_zero_array<index_t, nDimLow>();
|
||||
}
|
||||
|
||||
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
|
||||
// If idx_up_diff is known at compile-time, the calculation can
|
||||
// be done at compile-time. However, if idx_up_diff is only known
|
||||
// at run-time, then the calculation will also be computed at
|
||||
// run-time, and can be very expensive.
|
||||
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
|
||||
|
||||
if(idx_up_diff[0] > 0)
|
||||
else
|
||||
{
|
||||
bool carry = false;
|
||||
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
|
||||
// If idx_up_diff is known at compile-time, the calculation can
|
||||
// be done at compile-time. However, if idx_up_diff is only known
|
||||
// at run-time, then the calculation will also be computed at
|
||||
// run-time, and can be very expensive.
|
||||
LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
// find out the last low dimension that changed
|
||||
index_t last_changed_low_dim = 0;
|
||||
|
||||
static_for<0, nDimLow, 1>{}([&](auto i) {
|
||||
if(idx_low_diff_tmp[i] != 0)
|
||||
{
|
||||
last_changed_low_dim = i;
|
||||
}
|
||||
});
|
||||
|
||||
LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;
|
||||
|
||||
if(idx_up_diff[0] > 0)
|
||||
{
|
||||
// do carry check on each low dimension in reversed order
|
||||
// starting from the first digit that changed
|
||||
// don't check the highest dimension
|
||||
bool carry = false;
|
||||
|
||||
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
|
||||
if(i <= last_changed_low_dim)
|
||||
{
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(i);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(idx_low_new[i] >= LowerLengths::At(i))
|
||||
{
|
||||
idx_low_new(i) -= LowerLengths::At(i);
|
||||
carry = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(i);
|
||||
++idx_low_new(0);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(idx_low_new[i] >= LowerLengths::At(i))
|
||||
{
|
||||
idx_low_new(i) -= LowerLengths::At(i);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(0);
|
||||
}
|
||||
}
|
||||
else if(idx_up_diff[0] < 0)
|
||||
{
|
||||
bool borrow = false;
|
||||
else
|
||||
{
|
||||
// do borrow check on each low dimension in reversed order
|
||||
// starting from the first digit that changed
|
||||
// don't check the highest dimension
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow - 1, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
|
||||
if(i <= last_changed_low_dim)
|
||||
{
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(i);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(idx_low_new[i] < 0)
|
||||
{
|
||||
idx_low_new(i) += LowerLengths::At(i);
|
||||
borrow = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(i);
|
||||
--idx_low_new(0);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(idx_low_new[i] < 0)
|
||||
{
|
||||
idx_low_new(i) += LowerLengths::At(i);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension, no out-of-bound check
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(0);
|
||||
}
|
||||
}
|
||||
|
||||
return idx_low_new - idx_low_old;
|
||||
return idx_low_new - idx_low_old;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -290,8 +378,7 @@ struct UnMerge
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -300,7 +387,10 @@ struct UnMerge
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <typename UpperLengths, typename Coefficients>
template <index_t LowerLength,
typename UpperLengths,
typename Coefficients,
bool SkipIsValidCheck = false>
struct Embed
{
static constexpr index_t nDimLow = 1;
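For context (editorial), the Embed transform used throughout these kernels is just the affine map documented above; the new leading LowerLength parameter is what lets it reason about leaving the lower tensor:

    // Example: Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>> maps the
    // upper index (y, ho) to the single lower index
    //     hip = y * ConvDilationH + ho * ConvStrideH + 0
    // which is exactly the padded input row read by filter tap y at output row ho.
    // LowerLength = Hip is what the corner-based validity check below compares against.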
@@ -325,8 +415,10 @@ struct Embed
|
||||
{
|
||||
LowerIndex idx_low(Coefficients{}[nDimUp]);
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
|
||||
for(index_t i = 0; i < nDimUp; ++i)
|
||||
{
|
||||
idx_low(0) += idx_up[i] * Coefficients{}[i];
|
||||
}
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
@@ -338,18 +430,55 @@ struct Embed
{
LowerIndex idx_low_diff{0};

static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
for(index_t i = 0; i < nDimUp; ++i)
{
idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i];
}

return idx_low_diff;
}

__host__ __device__ static constexpr bool IsLinearTransform() { return true; }

__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
#if 1 // debug
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;

index_t ncorner = 1;

for(index_t idim = 0; idim < nDimUp; ++idim)
{
ncorner *= 2;
}

// loop over each corner of the upper tensor
for(index_t icorner = 0; icorner < ncorner; ++icorner)
{
// generate upper index for each corner
auto idx_up = make_zero_array<index_t, nDimUp>();

index_t itmp = icorner;

for(index_t idim = nDimUp - 1; idim >= 0; --idim)
{
idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1;
itmp /= 2;
}

// calculate lower index
auto idx_low = CalculateLowerIndex(idx_up);

// judge if lower index is valid
flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength;
}

return flag;
}
};

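The corner loop above relies on Embed being an affine map: if every corner of the upper index box lands inside [0, LowerLength), so does every interior point. A small host-side sketch of the same check, with hypothetical names and runtime values instead of Sequence<...> template parameters:

#include <vector>
#include <cstdio>

// Check whether idx_low = coeff[n] + sum_i coeff[i]*idx_up[i] stays inside
// [0, lower_length) for every corner, i.e. idx_up[i] in {0, up_lengths[i]-1}.
bool embed_always_valid(const std::vector<int>& up_lengths,
                        const std::vector<int>& coeff, // size up_lengths.size()+1
                        int lower_length)
{
    const int ndim    = static_cast<int>(up_lengths.size());
    const int ncorner = 1 << ndim;

    for(int corner = 0; corner < ncorner; ++corner)
    {
        int idx_low = coeff[ndim]; // constant term

        for(int i = 0; i < ndim; ++i)
        {
            const int idx_up = ((corner >> i) & 1) ? up_lengths[i] - 1 : 0;
            idx_low += coeff[i] * idx_up;
        }

        if(idx_low < 0 || idx_low >= lower_length)
            return false;
    }
    return true;
}

int main()
{
    // e.g. hip = dilation*y + stride*ho with Y=3, Ho=4, dilation=1, stride=2
    std::printf("%d\n", embed_always_valid({3, 4}, {1, 2, 0}, 10)); // 1: max corner is 8 < 10
    std::printf("%d\n", embed_always_valid({3, 4}, {1, 2, 0}, 8));  // 0: corner 8 is out of range
}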
@@ -389,8 +518,7 @@ struct Vectorize

__host__ __device__ static constexpr bool IsLinearTransform() { return true; }

__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}

@@ -53,6 +53,8 @@ struct NativeTensorCoordinate

__host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

__host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }

__host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }

__host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }
@@ -98,7 +100,24 @@ struct NativeTensorCoordinate
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
}

__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}

// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
// For native tensor, offset is valid if upper-index is valid
return IsUpperIndexValid();
}

// evaluated at compile-time
__host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
{
return true;
}

private:
// mIndex may be saved and updated, however, the value of some (or all) of its entries may
@@ -206,10 +225,30 @@ struct TransformedTensorCoordinate
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
}

__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
// evaluated at run-time
__host__ __device__ constexpr bool IsUpperIndexValid() const
{
return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
mCoordLow.IsUpperIndexMappedToValidOffset();
return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
}

// evaluated at run-time
__host__ __device__ constexpr bool IsOffsetValid() const
{
return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
}

// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
{
return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
GetLowerCoordinate().GetIndex());
}

// most evaluation is done at compile-time
__host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
{
return IsLowerIndexValidAssumingUpperIndexIsValid() &&
GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
}

private:

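The renamed predicates split the old IsUpperIndexMappedToValidOffset into two questions: whether the upper index lies inside the tensor at all (a run-time check), and whether a valid upper index necessarily produces a valid offset through the transform chain (largely decidable at compile time). A toy sketch of the layered check for a two-level coordinate, using hypothetical types rather than the CK classes:

#include <cstdio>

// lowest level: a "native" coordinate is valid iff its index is in range
struct NativeCoord
{
    int idx, length;
    bool IsUpperIndexValid() const { return idx >= 0 && idx < length; }
    bool IsOffsetValid() const { return IsUpperIndexValid(); } // nothing below it
};

// one transform on top: a valid offset requires this level's index to be valid
// AND the lower coordinate it maps to to be valid (e.g. not inside padding)
struct TransformedCoord
{
    int idx, length;
    NativeCoord low;
    bool IsUpperIndexValid() const { return idx >= 0 && idx < length; }
    bool IsOffsetValid() const { return IsUpperIndexValid() && low.IsOffsetValid(); }
};

int main()
{
    TransformedCoord c{2, 4, NativeCoord{-1, 8}}; // upper ok, lower out of range
    std::printf("%d %d\n", c.IsUpperIndexValid(), c.IsOffsetValid()); // 1 0
}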
@@ -120,10 +120,17 @@ struct NativeTensorDescriptor
return Tuple<>{};
}

__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
{
return true;
bool flag = true;

for(index_t i = 0; i < nDim; ++i)
{
flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
}

return flag;
}
};

@@ -467,33 +474,49 @@ struct TransformedTensorDescriptor
}
#endif

// a multi-index is valid if there is a corresponding point for it in the tensor
__host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
{
bool flag = true;

for(index_t i = 0; i < nDimUp; ++i)
{
flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
}

return flag;
}

// this function exists for optimization purposes; it is called by the tensor coordinate
// it tells you whether a lower-index is valid, assuming the upper index is valid
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
{
bool flag = true;

static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms{}.At(itran);

const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
// only check a transformation if it does not always have a valid mapping
constexpr bool is_valid_up_always_mapped_to_valid_low =
decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();

flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
if(!is_valid_up_always_mapped_to_valid_low)
{
constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
constexpr auto low_lengths_part =
GetLowerTensorDescriptor().GetLengths(low_dims_part);
const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part));

for(index_t i = 0; i < low_dims_part.Size(); ++i)
{
flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
}
}
});

return flag;
}

// Whenever this function is called, it will call CalculateLowerIndex() recursively.
// If you have created a tensor coordinate already, instead of calling this function,
// you should call TensorCoordinate::IsUpperIndexMappedToValidOffset() which would
// be less expensive.
__host__ __device__ static constexpr bool
IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
{
return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
CalculateLowerIndex(idx_up));
}
};

} // namespace ck

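IsLowerIndexValidAssumingUpperIndexIsValid only pays for a range check on transforms that can actually produce an out-of-range lower index (for instance Embed into a padded dimension); transforms flagged as always-valid are skipped at compile time. A minimal sketch of that dispatch using if constexpr and hypothetical transform types, not the CK machinery itself:

#include <cstdio>

struct PassThroughLike
{
    static constexpr bool always_valid = true;
    static bool in_range(int) { return true; } // never consulted for validity
};

struct EmbedIntoPaddedDim
{
    static constexpr bool always_valid = false;
    static bool in_range(int idx_low) { return idx_low >= 0 && idx_low < 16; }
};

template <typename... Transforms>
bool lower_index_valid(int idx_low)
{
    bool flag = true;
    // fold over all transforms; the check is compiled out where always_valid is true
    ((flag = flag && ([&] {
          if constexpr(Transforms::always_valid)
              return true;
          else
              return Transforms::in_range(idx_low);
      }())),
     ...);
    return flag;
}

int main()
{
    std::printf("%d\n", lower_index_valid<PassThroughLike, EmbedIntoPaddedDim>(20)); // 0
    std::printf("%d\n", lower_index_valid<PassThroughLike, EmbedIntoPaddedDim>(3));  // 1
}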
@@ -50,9 +50,37 @@ template <index_t GridSize,
index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

return 2 * (a_block_space + b_block_space) * sizeof(Float);
}

__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto True = integral_constant<bool, true>{};

@@ -64,6 +92,12 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];

// don't do anything if K == 0
if(K == 0)
{
return;
}

// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
@@ -184,8 +218,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;

// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
@@ -329,6 +363,17 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
.Run(p_c_thread, p_c_global);
}
}

__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

__shared__ Float p_shared_block[shared_block_size];

Run(p_a_global, p_b_global, p_c_global, p_shared_block);
}
};

} // namespace ck

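These hunks turn the LDS buffers into a caller-provided workspace: the kernel declares one __shared__ block sized by GetSharedMemoryNumberOfByte() and hands a pointer to the worker Run, so several GEMM-like stages can reuse a single allocation, while the old two-argument entry point survives as a thin wrapper. A stripped-down sketch of the pattern with hypothetical tile sizes and names (HIP/CUDA device code):

#include <cstddef>

using index_t = int;

struct TinyGemm
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // double-buffered A (64x16) and B (64x16) tiles of float
        return 2 * (64 * 16 + 64 * 16) * sizeof(float);
    }

    // worker: operates on whatever LDS block the caller provides
    __device__ void Run(const float* p_a, const float* p_b, float* p_c,
                        float* p_shared_block) const
    {
        float* p_a_block_double = p_shared_block;
        float* p_b_block_double = p_shared_block + 2 * 64 * 16;
        // ... blockwise copies and the GEMM main loop would go here ...
        (void)p_a; (void)p_b; (void)p_c; (void)p_a_block_double; (void)p_b_block_double;
    }

    // convenience overload: allocate the LDS block itself, then forward
    __device__ void Run(const float* p_a, const float* p_b, float* p_c) const
    {
        constexpr index_t n = GetSharedMemoryNumberOfByte() / sizeof(float);
        __shared__ float p_shared_block[n];
        Run(p_a, p_b, p_c, p_shared_block);
    }
};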
@@ -110,7 +110,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check src data's valid mapping situation, only check the first data in this src
// vector. It is the user's responsibility to make sure all data in the src vector
// have the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
@@ -142,7 +142,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Check dst data's valid mapping situation, only check the first data in this dst
// vector. It is the user's responsibility to make sure all data in the dst vector
// have the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
@@ -260,7 +260,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It is the user's responsibility to make sure all data in the src vector
// have the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
@@ -299,7 +299,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It is the user's responsibility to make sure all data in the dst vector
// have the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
@@ -399,7 +399,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// src
// vector. It is the user's responsibility to make sure all data in the src vector
// have the same valid/invalid mapping situation
if(src_coord.IsUpperIndexMappedToValidOffset())
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
@@ -444,7 +444,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// dst
// vector. It is the user's responsibility to make sure all data in the dst vector
// have the same valid/invalid mapping situation
if(dst_coord.IsUpperIndexMappedToValidOffset())
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,

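The renamed predicate keeps the same guard pattern: each vectorized access checks validity of only its first element and skips the transfer when the coordinate points into padding, leaving the destination at its initialized value. A scalar-level sketch of that guard in plain C++ with hypothetical buffers:

#include <vector>
#include <cstdio>

// copy src[offset] into dst[i] only when the offset is valid; padding reads
// are skipped so dst keeps its initialized value (here: zero)
void guarded_gather(const std::vector<float>& src,
                    const std::vector<long>& offsets, // may point outside src (padding)
                    std::vector<float>& dst)
{
    for(std::size_t i = 0; i < offsets.size(); ++i)
    {
        const long off   = offsets[i];
        const bool valid = off >= 0 && off < static_cast<long>(src.size());

        if(valid) // mirrors "if(coord.IsOffsetValidAssumingUpperIndexIsValid())"
            dst[i] = src[off];
    }
}

int main()
{
    std::vector<float> src{1.f, 2.f, 3.f};
    std::vector<float> dst(4, 0.f);
    guarded_gather(src, {-1, 0, 2, 5}, dst); // -1 and 5 fall into padding
    for(float v : dst) std::printf("%g ", v); // 0 1 3 0
    std::printf("\n");
}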
@@ -54,6 +54,13 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");

__device__ void
__llvm_amdgcn_buffer_atomic_add(float vdata,
int32x4_t rsrc,
index_t vindex,
index_t offset,
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");

// buffer_load requires:
// 1) p_src must be in global memory space, d_dst must be vgpr
// 2) p_src to be a block-invariant pointer.
@@ -73,6 +80,13 @@ amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType
index_t dst_thread_data_offset,
index_t dst_const_data_offset);

template <typename T, index_t VectorSize>
__device__ void
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src,
T* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset);

template <>
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
index_t src_thread_data_offset,
@@ -289,5 +303,31 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
#endif
}

template <>
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
float* p_dst_block,
index_t dst_thread_data_offset,
index_t dst_const_data_offset)
{
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);

BufferLoadStoreDwordConfig<float> dst_block_config;

// fill in byte 0 - 1
dst_block_config.address[0] = p_dst_block;
// fill in byte 2
dst_block_config.range[2] = -1;
// fill in byte 3
dst_block_config.range[3] = 0x00027000;

#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
__llvm_amdgcn_buffer_atomic_add(
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
#else
static_assert(false, "wrong! not implemented");
#endif
}

} // namespace ck
#endif

@@ -29,6 +29,11 @@
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
#endif

// only supported on gfx908
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0
#endif

// AMD XDLOPS
#ifndef CK_USE_AMD_XDLOPS
#define CK_USE_AMD_XDLOPS 0

@@ -29,9 +29,10 @@ struct static_for
{
__host__ __device__ constexpr static_for()
{
static_assert(NBegin <= NEnd, "Wrong! should have NBegin <= NEnd");
static_assert((NEnd - NBegin) % Increment == 0,
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"Wrong! should have NBegin <= NEnd for positive Increment, NBegin >= NEnd for negative Increment");
}

template <class F>

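With the relaxed static_asserts, static_for can also count downward, which is what the ireverse pattern in the Merge transform emulates. A small sketch of an equivalent compile-time reverse loop built on std::integer_sequence, hypothetical and separate from the CK static_for:

#include <utility>
#include <cstdio>

// call f(integral_constant<int, NBegin>), f(NBegin + Inc), ... stopping before NEnd
template <int NBegin, int NEnd, int Inc, typename F, int... Is>
void static_for_impl(F f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, NBegin + Is * Inc>{}), ...);
}

template <int NBegin, int NEnd, int Inc, typename F>
void static_for_demo(F f)
{
    static_assert(Inc != 0 && (NEnd - NBegin) % Inc == 0, "bad step");
    static_assert((Inc > 0 && NBegin <= NEnd) || (Inc < 0 && NBegin >= NEnd), "bad range");
    static_for_impl<NBegin, NEnd, Inc>(
        f, std::make_integer_sequence<int, (NEnd - NBegin) / Inc>{});
}

int main()
{
    // visit dimensions 3, 2, 1 in that order (skipping dimension 0), like the borrow loop
    static_for_demo<3, 0, -1>([](auto i) { std::printf("%d ", i.value); }); // 3 2 1
    std::printf("\n");
}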
@@ -52,8 +52,13 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in

static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
#else
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
#endif
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");

@@ -49,6 +49,12 @@ struct integer_divide_ceiler
}
};

template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}

template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{

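The hunk above is truncated before the body of integer_divide_ceil. For positive integers the conventional formulation is (x + y - 1) / y; the snippet below assumes that formulation rather than quoting the patch:

#include <cassert>

// ceiling division for positive integers; the body is the conventional
// formulation and an assumption here, since the hunk above is cut off
constexpr int integer_divide_ceil_example(int x, int y) { return (x + y - 1) / y; }

int main()
{
    static_assert(integer_divide_ceil_example(10, 4) == 3, "");
    static_assert(integer_divide_ceil_example(8, 4) == 2, "");
    assert(integer_divide_ceil_example(1, 8) == 1);
}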