mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Code clean up (#20)
* tuning para, * testing on v100 * add fp16 * remove deprecated tensor descriptor * sync with miopen * update build script Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -1,12 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
|
||||
#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
|
||||
|
||||
template <class GridwiseConvolution, class T>
|
||||
__global__ void run_gridwise_convolution_kernel(const T* const __restrict__ p_in_global,
|
||||
const T* const __restrict__ p_wei_global,
|
||||
T* const __restrict__ p_out_global)
|
||||
{
|
||||
GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -2,7 +2,7 @@
|
||||
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
|
||||
|
||||
template <typename GridwiseOp, typename... Xs>
|
||||
__global__ void run_gridwise_operation(GridwiseOp, Xs... xs)
|
||||
__global__ void run_gridwise_operation(Xs... xs)
|
||||
{
|
||||
GridwiseOp{}.Run(xs...);
|
||||
}
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
#ifndef CK_CONVOLUTION_COMMON_HPP
|
||||
#define CK_CONVOLUTION_COMMON_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum ConvolutionDirection
|
||||
{
|
||||
Forward,
|
||||
BackwardData,
|
||||
BackwardWeight
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,130 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_COL2IM_EB_NCHW_HPP
|
||||
#define CK_GRIDWISE_COL2IM_EB_NCHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// B = merge(N, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename ColGlobalDesc,
|
||||
typename ImgGlobalDesc,
|
||||
typename FilterSizes,
|
||||
typename OutputSizes,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t EPerBlock,
|
||||
index_t BPerBlock,
|
||||
typename BlockCopySubLengths_E_B,
|
||||
typename BlockCopyClusterLengths_E_B,
|
||||
typename BlockCopyThreadClusterArrangeOrder,
|
||||
typename BlockCopySrcAccessOrder,
|
||||
typename BlockCopyDstAccessOrder,
|
||||
index_t BlockCopyDataPerAccess_B>
|
||||
struct GridwiseCol2Im_eb_nchw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_col_global,
|
||||
Float* const __restrict__ p_img_global) const
|
||||
{
|
||||
constexpr auto col_e_b_global_desc = ColGlobalDesc{};
|
||||
constexpr auto img_n_c_hi_wi_global_desc = ImgGlobalDesc{};
|
||||
|
||||
constexpr index_t N = img_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = img_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = img_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = img_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = OutputSizes{}[0];
|
||||
constexpr index_t Wo = OutputSizes{}[1];
|
||||
|
||||
constexpr index_t Y = FilterSizes{}[0];
|
||||
constexpr index_t X = FilterSizes{}[1];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || BlockCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % BlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [E, B]
|
||||
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t EBlockWork = E / EPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
|
||||
// construct img_eb_global_desc
|
||||
constexpr auto img_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
img_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto img_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
img_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto img_e_b_global_desc = transform_tensor_descriptor(
|
||||
img_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// blockwise atomic accumulation
|
||||
auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(col_e_b_global_desc),
|
||||
decltype(img_e_b_global_desc),
|
||||
Sequence<EPerBlock, BPerBlock>,
|
||||
BlockCopySubLengths_E_B,
|
||||
BlockCopyClusterLengths_E_B,
|
||||
BlockCopyThreadClusterArrangeOrder,
|
||||
BlockCopySrcAccessOrder,
|
||||
BlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
BlockCopyDataPerAccess_B,
|
||||
BlockCopyDataPerAccess_B,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Global,
|
||||
InMemoryDataOperation::AtomicAdd>(
|
||||
{e_block_data_on_global, b_block_data_on_global},
|
||||
{e_block_data_on_global, b_block_data_on_global});
|
||||
|
||||
// blockwise copy
|
||||
blockwise_copy.Run(p_col_global, p_img_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -36,8 +36,8 @@ template <index_t GridSize,
|
||||
index_t ThreadGemmDataPerRead_GemmN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmN,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
@@ -82,13 +82,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
constexpr auto wei_gemmk_gemmm_global_desc =
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
@@ -98,16 +91,15 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
|
||||
Sequence<Y, Ho>,
|
||||
Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
|
||||
Sequence<X, Wo>,
|
||||
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -117,6 +109,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
// \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
|
||||
// atomic, find out all of them
|
||||
@@ -152,8 +151,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmABlockCopySrcDataPerRead_GemmN,
|
||||
GemmABlockCopyDstDataPerWrite_GemmN,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
|
||||
@@ -25,13 +25,13 @@ template <index_t GridSize,
|
||||
index_t EPerBlock,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMPerThread,
|
||||
index_t GemmNPerThread,
|
||||
index_t GemmKPerThread,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename OutBlockCopySubLengths_K_B_N0,
|
||||
@@ -78,8 +78,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t C0 = GemmMPerThreadSubC;
|
||||
constexpr index_t N0 = GemmNPerThreadSubC;
|
||||
constexpr index_t C0 = GemmMPerThread;
|
||||
constexpr index_t N0 = GemmNPerThread;
|
||||
|
||||
static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");
|
||||
|
||||
@@ -225,20 +225,20 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO:: more elegent way of defining c_thread_mtx
|
||||
constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_k_ec0_block_mtx_desc),
|
||||
decltype(b_k_bn0_block_mtx_desc),
|
||||
decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
@@ -371,7 +371,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
// define input tensor descriptor for threadwise copy
|
||||
// thread input tensor, src of threadwise copy
|
||||
constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, 1, GemmMPerThreadSubC, GemmNRepeat, 1, GemmNPerThreadSubC>{});
|
||||
Sequence<GemmMRepeat, 1, GemmMPerThread, GemmNRepeat, 1, GemmNPerThread>{});
|
||||
|
||||
// global input tensor, dst of threadwise copy
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
@@ -419,10 +419,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t e_thread_data_on_global =
|
||||
e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThreadSubC;
|
||||
e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThreadSubC;
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
|
||||
|
||||
@@ -419,7 +419,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
template <index_t GemmId>
|
||||
__device__ static void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global)
|
||||
const Float* __restrict__ p_out_global,
|
||||
Number<GemmId>)
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
@@ -1,255 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
|
||||
#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_direct_convolution.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t CPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead>
|
||||
struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_nchw_global_desc.GetLength(I0);
|
||||
constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{},
|
||||
Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<KPerBlock, CPerBlock * Y * X>{},
|
||||
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr auto wei_kcyx_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
|
||||
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
|
||||
|
||||
// shared mem
|
||||
constexpr index_t in_block_element_size =
|
||||
in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
constexpr index_t wei_block_element_size =
|
||||
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ Float
|
||||
p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ Float
|
||||
p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
|
||||
|
||||
// threadwise tensors
|
||||
constexpr index_t HiPerThread = HoPerThread + Y - 1;
|
||||
constexpr index_t WiPerThread = WoPerThread + X - 1;
|
||||
|
||||
constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
|
||||
in_nchw_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_nkhw_thread_desc =
|
||||
get_convolution_output_default_4d_tensor_descriptor_deprecated(
|
||||
in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr index_t NBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
constexpr index_t KBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
const index_t block_id = blockIdx.x;
|
||||
|
||||
index_t itmp = block_id;
|
||||
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const index_t h_block_work_id = itmp / WBlockWork;
|
||||
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
|
||||
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
|
||||
|
||||
// divide thread work
|
||||
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
|
||||
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
|
||||
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
|
||||
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
|
||||
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
itmp = thread_id;
|
||||
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
|
||||
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
|
||||
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
|
||||
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
|
||||
const index_t h_thread_work_id = itmp / WThreadWork;
|
||||
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
|
||||
|
||||
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
|
||||
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
|
||||
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
|
||||
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
|
||||
|
||||
const index_t hi_thread_data_begin = ho_thread_data_begin;
|
||||
const index_t wi_thread_data_begin = wo_thread_data_begin;
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_nchw_global_desc),
|
||||
decltype(in_nchw_block_desc),
|
||||
decltype(in_nchw_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_kcyx_global_desc),
|
||||
decltype(wei_kcyx_block_desc),
|
||||
decltype(wei_kcyx_block_desc.GetLengths()),
|
||||
1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ke_global_desc),
|
||||
decltype(wei_ke_block_desc),
|
||||
decltype(wei_ke_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>({0, 0}, {0, 0});
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global +
|
||||
in_nchw_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_global +
|
||||
wei_kcyx_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block +
|
||||
in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block +
|
||||
in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// copy output tensor from register to global mem
|
||||
threadwise_tensor_slice_copy(out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths(),
|
||||
Number<1>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,398 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerRead_N,
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// flattend (2d) tensor view of gridwise weight
|
||||
constexpr auto wei_cyx_k_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
|
||||
Number<InBlockCopyDataPerRead_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * Y * X, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
const auto blockwise_in_copy =
|
||||
#if 0
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#else
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_cyx_k_global_desc),
|
||||
decltype(wei_cyx_k_block_desc),
|
||||
decltype(wei_cyx_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_c_y_x_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
wei_c_y_x_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// C++ lambda doesn't capture array, use pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
#if 1
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
#else
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_batch_gemm.Run
|
||||
#else
|
||||
blockwise_batch_gemm.Run_amd_asm
|
||||
#endif
|
||||
(p_wei_block + wei_c_y_x_k_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).Else([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
|
||||
for(index_t i = 0; i < 64; ++i)
|
||||
{
|
||||
printf("out %f, ", p_out_thread[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,435 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_3d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerRead_N,
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_x_k_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{},
|
||||
Number<InBlockCopyDataPerRead_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, X, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
#if 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#else
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise3dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_x_k_global_desc),
|
||||
decltype(wei_c_x_k_block_desc),
|
||||
decltype(wei_c_x_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space =
|
||||
wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// C++ lambda doesn't capture array, use pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
#if 1
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global_block_offset +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
|
||||
p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global_block_offset +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(
|
||||
p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
|
||||
p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
#else
|
||||
// this use much more register, haven't figure out why?
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin);
|
||||
|
||||
for(index_t
|
||||
c_block_data_begin = 0;
|
||||
c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(
|
||||
p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
|
||||
p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).Else([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
|
||||
for(index_t i = 0; i < 64; ++i)
|
||||
{
|
||||
printf("out %f, ", p_out_thread[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
} // namespace ck
#endif
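The epilogue of the kernel above folds the 4-d output tensor [K, Ho, Wo, N] into a 10-d descriptor so that each thread's register tile becomes a contiguous slice for the threadwise copy. The standalone host-side sketch below (illustration only, not part of this commit; the fold factors K1/K2, W1/W2, N1/N2 are made-up example values, not the kernel's tuning parameters) checks that such a fold is purely a regrouping of indices and strides and never changes the linear offset of an element.

// fold_sketch.cpp -- verify that folding a packed 4-d view into 10-d keeps offsets unchanged
#include <array>
#include <cassert>
#include <cstddef>

int main()
{
    // packed 4-d lengths [K, Ho, Wo, N] and their row-major strides
    constexpr std::size_t K = 8, Ho = 4, Wo = 8, N = 16;
    constexpr std::size_t sN = 1, sW = N, sH = Wo * N, sK = Ho * Wo * N;

    // fold K = K0*K1*K2, Wo = W0*W1*W2, N = N0*N1*N2 (Ho is left unfolded)
    constexpr std::size_t K1 = 2, K2 = 2, K0 = K / (K1 * K2);
    constexpr std::size_t W1 = 2, W2 = 4, W0 = Wo / (W1 * W2);
    constexpr std::size_t N1 = 4, N2 = 4, N0 = N / (N1 * N2);

    // strides of the folded 10-d view: sub-dimensions of a folded group share the
    // original stride scaled by the lengths to their right
    const std::array<std::size_t, 10> len = {K0, K1, K2, Ho, W0, W1, W2, N0, N1, N2};
    const std::array<std::size_t, 10> str = {K1 * K2 * sK, K2 * sK, sK,
                                             sH,
                                             W1 * W2 * sW, W2 * sW, sW,
                                             N1 * N2 * sN, N2 * sN, sN};

    // an arbitrary valid 10-d index, addressed through the folded view
    const std::array<std::size_t, 10> idx = {1, 0, 1, 3, 0, 1, 2, 0, 2, 3};
    std::size_t offset_folded = 0;
    for(std::size_t d = 0; d < 10; ++d)
    {
        assert(idx[d] < len[d]);
        offset_folded += idx[d] * str[d];
    }

    // the same element addressed through the original 4-d view
    const std::size_t k  = (idx[0] * K1 + idx[1]) * K2 + idx[2];
    const std::size_t ho = idx[3];
    const std::size_t wo = (idx[4] * W1 + idx[5]) * W2 + idx[6];
    const std::size_t n  = (idx[7] * N1 + idx[8]) * N2 + idx[9];
    const std::size_t offset_4d = k * sK + ho * sH + wo * sW + n * sN;

    assert(offset_folded == offset_4d); // folding only regroups indices; the layout is unchanged
    return 0;
}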
@@ -1,420 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_CHWN,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerAccess_N,
|
||||
class WeiBlockCopySubLengths_CK,
|
||||
class WeiBlockCopyClusterLengths_CK,
|
||||
index_t WeiBlockCopyDataPerAccess_K,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
              "GemmDataPerReadB alignment requirement is not met");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0},
|
||||
{0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace();
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// a C++ lambda doesn't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
    // fwd does nothing but perfect forwarding.
    // Using this trick makes this lambda a generic lambda, so it won't be compiled until
    // it is instantiated here.
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"a: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"b: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
} // namespace ck
#endif
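The kernels in this commit implement forward convolution as implicit GEMM: for every filter tap (y, x), a [C x K]^T * [C x (Ho*Wo*N)] product is accumulated into the output, as the blockwise batched GEMM comments above describe. The host-side reference below is a sketch for illustration only (it assumes stride 1, no padding, and the CHWN/CYXK/KHWN layouts used by these kernels; it is not part of the commit) and shows that this per-tap GEMM loop order reproduces a direct convolution.

// implicit_gemm_reference.cpp -- compare per-(y, x) GEMM accumulation with direct convolution
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const int C = 3, Y = 3, X = 3, K = 4, Hi = 6, Wi = 6, N = 2;
    const int Ho = Hi - Y + 1, Wo = Wi - X + 1;

    // indexing helpers for the CHWN / CYXK / KHWN layouts
    auto in  = [&](std::vector<float>& t, int c, int h, int w, int n) -> float& {
        return t[((c * Hi + h) * Wi + w) * N + n]; };
    auto wei = [&](std::vector<float>& t, int c, int y, int x, int k) -> float& {
        return t[((c * Y + y) * X + x) * K + k]; };
    auto out = [&](std::vector<float>& t, int k, int h, int w, int n) -> float& {
        return t[((k * Ho + h) * Wo + w) * N + n]; };

    std::vector<float> p_in(C * Hi * Wi * N), p_wei(C * Y * X * K);
    std::vector<float> out_gemm(K * Ho * Wo * N, 0.f), out_direct(K * Ho * Wo * N, 0.f);
    for(std::size_t i = 0; i < p_in.size(); ++i)  p_in[i]  = float(i % 7) - 3.f;
    for(std::size_t i = 0; i < p_wei.size(); ++i) p_wei[i] = float(i % 5) - 2.f;

    // implicit-GEMM order: one [K x (Ho*Wo*N)] += [C x K]^T * [C x (Ho*Wo*N)] per (y, x)
    for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x)
            for(int k = 0; k < K; ++k)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                        for(int n = 0; n < N; ++n)
                            for(int c = 0; c < C; ++c)
                                out(out_gemm, k, ho, wo, n) +=
                                    wei(p_wei, c, y, x, k) * in(p_in, c, ho + y, wo + x, n);

    // direct convolution for comparison
    for(int k = 0; k < K; ++k)
        for(int ho = 0; ho < Ho; ++ho)
            for(int wo = 0; wo < Wo; ++wo)
                for(int n = 0; n < N; ++n)
                    for(int c = 0; c < C; ++c)
                        for(int y = 0; y < Y; ++y)
                            for(int x = 0; x < X; ++x)
                                out(out_direct, k, ho, wo, n) +=
                                    wei(p_wei, c, y, x, k) * in(p_in, c, ho + y, wo + x, n);

    // both sums are exact small integers here, so the results match bit for bit
    assert(out_gemm == out_direct);
    return 0;
}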
@@ -1,508 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_CHWN,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerAccess_N,
|
||||
class WeiBlockCopySubLengths_CK,
|
||||
class WeiBlockCopyClusterLengths_CK,
|
||||
index_t WeiBlockCopyDataPerAccess_K,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % (2 * CPerBlock) == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
              "GemmDataPerReadB alignment requirement is not met");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy =
|
||||
#if 0
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
#else
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
#endif
|
||||
<BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
#if 0
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
#else
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
#endif
|
||||
<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace();
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register
|
||||
// a C++ lambda doesn't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
    // fwd does nothing but perfect forwarding.
    // Using this trick makes this lambda a generic lambda, so it won't be compiled until
    // it is instantiated here.
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"a: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"b: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
} // namespace ck
#endif
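The lds_double_buffer variant above hides global-memory latency by ping-ponging between two LDS buffers: while the blockwise GEMM consumes the chunk already resident in one buffer, the next chunk is fetched into the other, and a tail step drains the final pair of chunks. Below is a minimal host-side sketch of the same schedule (illustration only; a chunked sum stands in for the GEMM, and the even chunk count mirrors the kernel's C % (2 * CPerBlock) == 0 requirement).

// double_buffer_sketch.cpp -- ping-pong buffering: preload, paired main body, two-chunk tail
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

int main()
{
    const int chunk = 4, num_chunks = 6; // num_chunks must be even
    std::vector<int> global(chunk * num_chunks);
    std::iota(global.begin(), global.end(), 0);

    std::vector<int> buf[2] = {std::vector<int>(chunk), std::vector<int>(chunk)};
    long acc = 0;

    auto load = [&](int i, std::vector<int>& dst) { // "global -> LDS" copy
        std::copy(global.begin() + i * chunk, global.begin() + (i + 1) * chunk, dst.begin()); };
    auto compute = [&](const std::vector<int>& src) { // stand-in for the blockwise GEMM
        for(int v : src) acc += v; };

    load(0, buf[0]);                           // preload the first chunk
    for(int i = 0; i + 2 < num_chunks; i += 2) // main body: two chunks per pass
    {
        load(i + 1, buf[1]);  compute(buf[0]); // fetch the next chunk while consuming the current one
        load(i + 2, buf[0]);  compute(buf[1]);
    }
    load(num_chunks - 1, buf[1]);              // tail: last fetch ...
    compute(buf[0]);                           // ... consume the even chunk ...
    compute(buf[1]);                           // ... then the odd chunk

    assert(acc == std::accumulate(global.begin(), global.end(), 0L));
    return 0;
}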
@@ -1,414 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class LeftPads,
|
||||
class RightPads,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_CHWN,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerAccess_N,
|
||||
class WeiBlockCopySubLengths_CK,
|
||||
class WeiBlockCopyClusterLengths_CK,
|
||||
index_t WeiBlockCopyDataPerAccess_K,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto True = integral_constant<bool, true>{};
|
||||
static constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc_old = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc_old = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc_old = OutGlobalDesc{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = make_native_tensor_descriptor(
|
||||
in_c_h_w_n_global_desc_old.GetLengths(), in_c_h_w_n_global_desc_old.GetStrides());
|
||||
|
||||
constexpr auto wei_c_y_x_k_global_desc = make_native_tensor_descriptor(
|
||||
wei_c_y_x_k_global_desc_old.GetLengths(), wei_c_y_x_k_global_desc_old.GetStrides());
|
||||
|
||||
constexpr auto out_k_h_w_n_global_desc = make_native_tensor_descriptor(
|
||||
out_k_h_w_n_global_desc_old.GetLengths(), out_k_h_w_n_global_desc_old.GetStrides());
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
|
||||
|
||||
const index_t hp_block_data_begin = ho_block_data_begin;
|
||||
const index_t wp_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// input global tensor view
|
||||
constexpr auto in_c_hp_wp_n_global_desc = transform_tensor_descriptor(
|
||||
in_c_h_w_n_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}, PassThrough<N>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc_old = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// hack
|
||||
constexpr auto in_c_h_w_n_block_desc = make_native_tensor_descriptor(
|
||||
in_c_h_w_n_block_desc_old.GetLengths(), in_c_h_w_n_block_desc_old.GetStrides());
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
              "GemmDataPerReadB alignment requirement is not met");
|
||||
|
||||
constexpr auto wei_c_1_1_k_block_desc_old = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, 1, 1, KPerBlock>{}, Number<max_align>{});
|
||||
|
||||
constexpr auto wei_c_1_1_k_block_desc = make_native_tensor_descriptor(
|
||||
wei_c_1_1_k_block_desc_old.GetLengths(), wei_c_1_1_k_block_desc_old.GetStrides());
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_c_h_w_n_block_desc_old.GetElementSpace();
|
||||
constexpr index_t wei_block_space = wei_c_1_1_k_block_desc_old.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc_old = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_native_tensor_descriptor(
|
||||
out_k_h_w_n_thread_desc_old.GetLengths(), out_k_h_w_n_thread_desc_old.GetStrides());
|
||||
|
||||
// blockwise input copy
|
||||
// format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_c_hp_wp_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>(
|
||||
{0, hp_block_data_begin, wp_block_data_begin, n_block_data_begin}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
using WeiBlockCopySubLengths_CYXK =
|
||||
Sequence<WeiBlockCopySubLengths_CK::At(0), 1, 1, WeiBlockCopySubLengths_CK::At(1)>;
|
||||
using WeiBlockCopyClusterLengths_CYXK = Sequence<WeiBlockCopyClusterLengths_CK::At(0),
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyClusterLengths_CK::At(1)>;
|
||||
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_c_y_x_k_global_desc),
|
||||
decltype(wei_c_1_1_k_block_desc),
|
||||
decltype(wei_c_1_1_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CYXK,
|
||||
WeiBlockCopyClusterLengths_CYXK,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>(
|
||||
{0, 0, 0, k_block_data_begin}, {0, 0, 0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_c_1_1_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// register
|
||||
// a C++ lambda doesn't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc_old.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// move along C
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(CPerBlock, 0, 0, 0),
|
||||
True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(CPerBlock, 0, 0, 0),
|
||||
True);
|
||||
}
|
||||
|
||||
// reset C
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(C, 0, 0, 0), False);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(C, 0, 0, 0), False);
|
||||
|
||||
// move along X
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 0, 1, 0), True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 0, 1, 0), True);
|
||||
}
|
||||
|
||||
// reset X
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 0, X, 0), False);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 0, X, 0), False);
|
||||
|
||||
// move along Y
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 1, 0, 0), True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 1, 0, 0), True);
|
||||
}
|
||||
|
||||
// output: register to global mem
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
    // fwd does nothing but perfect forwarding.
    // Using this trick makes this lambda a generic lambda, so it won't be compiled until
    // it is instantiated here.
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc_old = fwd(out_k_h_w_n_global_desc_old)
|
||||
.Fold(I3, Number<N1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_global_desc = make_native_tensor_descriptor(
|
||||
out_10d_global_desc_old.GetLengths(), out_10d_global_desc_old.GetStrides());
|
||||
|
||||
constexpr auto out_10d_thread_desc_old = fwd(out_k_h_w_n_thread_desc_old)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_native_tensor_descriptor(
|
||||
out_10d_thread_desc_old.GetLengths(), out_10d_thread_desc_old.GetStrides());
|
||||
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.CalculateOffset({k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin});
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc_old =
|
||||
fwd(out_k_h_w_n_global_desc_old)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_global_desc = make_native_tensor_descriptor(
|
||||
out_10d_global_desc_old.GetLengths(), out_10d_global_desc_old.GetStrides());
|
||||
|
||||
constexpr auto out_10d_thread_desc_old =
|
||||
fwd(out_k_h_w_n_thread_desc_old)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_native_tensor_descriptor(
    out_10d_thread_desc_old.GetLengths(), out_10d_thread_desc_old.GetStrides());
|
||||
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.CalculateOffset({k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin});
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
});
|
||||
}
|
||||
};
} // namespace ck
#endif
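The padded variant above never materializes a padded input; it composes the input descriptor with a Pad transform so the blockwise copy indexes a virtually enlarged [Hi + pads, Wi + pads] view, and coordinates that fall in the halo behave as zeros. A small host-side analogue of that idea (my own illustration, not the descriptor machinery in this commit):

// implicit_padding_sketch.cpp -- index a virtually padded view without copying the tensor
#include <cassert>
#include <vector>

struct PaddedView2d
{
    const std::vector<float>& data; // unpadded Hi x Wi tensor, row-major
    int Hi, Wi, left_pad_h, left_pad_w;

    float operator()(int hp, int wp) const // hp, wp are coordinates in the padded view
    {
        const int h = hp - left_pad_h;
        const int w = wp - left_pad_w;
        const bool inside = (h >= 0 && h < Hi && w >= 0 && w < Wi);
        return inside ? data[h * Wi + w] : 0.f; // halo elements read back as zeros
    }
};

int main()
{
    const int Hi = 3, Wi = 3;
    std::vector<float> img = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    PaddedView2d view{img, Hi, Wi, /*left_pad_h=*/1, /*left_pad_w=*/1};

    assert(view(0, 0) == 0.f); // top-left halo
    assert(view(1, 1) == 1.f); // first real element
    assert(view(4, 4) == 0.f); // bottom-right halo
    return 0;
}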
@@ -1,451 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockReorderSrcSubLengths_NCHW,
|
||||
class InBlockReorderSrcClusterLengths_NCHW,
|
||||
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
index_t InBlockReorderDataPerRead_W,
|
||||
index_t InBlockReorderDataPerWrite_N,
|
||||
class WeiBlockCopyClusterLengths_CK, // not used
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_W>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [N, K, Ho, Wo]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
|
||||
Number<InBlockReorderDataPerWrite_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
|
||||
|
||||
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_amd_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// a C++ lambda cannot capture an array by value, so use a pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
#if 0
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
|
||||
p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd does nothing but perfect forwarding.
// This trick makes the lambda a generic lambda, so it is not compiled until
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"a: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
|
||||
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"b: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
|
||||
|
||||
#if 0
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type{},
|
||||
Number<1>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
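The kernel above divides the output grid over workgroups by building a packed descriptor over [NBlockWork, KBlockWork, HBlockWork, WBlockWork] and turning the 1-d block id into a 4-d block coordinate; each coordinate is then scaled by the corresponding *PerBlock size to get the block's data origin. A short host-side sketch of that decomposition follows; it illustrates the index math only and is not the library's API.

// Host-side sketch of mapping a 1-d block id onto a packed 4-d work grid.
#include <array>
#include <cstdio>

std::array<int, 4> block_id_to_multi_index(int block_id, std::array<int, 4> lengths)
{
    // row-major (packed) decomposition: the last dimension varies fastest
    std::array<int, 4> idx{};
    for(int d = 3; d >= 0; --d)
    {
        idx[d] = block_id % lengths[d];
        block_id /= lengths[d];
    }
    return idx;
}

int main()
{
    // hypothetical partition: N/NPerBlock = 2, K/KPerBlock = 4, Ho/HoPerBlock = 7, Wo/WoPerBlock = 7
    const std::array<int, 4> block_work_lengths{2, 4, 7, 7};
    const auto id = block_id_to_multi_index(53, block_work_lengths);
    printf("block 53 -> n=%d k=%d ho=%d wo=%d\n", id[0], id[1], id[2], id[3]);
    return 0;
}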
@@ -1,499 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockReorderSrcSubLengths_NCHW,
|
||||
class InBlockReorderSrcClusterLengths_NCHW,
|
||||
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
index_t InBlockReorderDataPerRead_W,
|
||||
index_t InBlockReorderDataPerWrite_N,
|
||||
class WeiBlockCopyClusterLengths_CK, // not used
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_W>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
|
||||
Number<InBlockReorderDataPerWrite_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment requirements
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
|
||||
|
||||
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_amd_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register
|
||||
// a C++ lambda cannot capture an array by value, so use a pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float
|
||||
p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterBuffer(
|
||||
p_in_register_buffer, p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd does nothing but perfect forwarding.
// This trick makes the lambda a generic lambda, so it is not compiled until
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"a: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
|
||||
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"b: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
|
||||
|
||||
#if 0
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type{},
|
||||
Number<1>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
#endif
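The lds_double_buffer variant above requires C % (2 * CPerBlock) == 0 so the reduction loop can ping-pong between the two halves of the LDS buffers: while the GEMM consumes the current half, the next chunk is fetched from global memory into registers and then stored into the other half, with a tail step that drains the last buffer. A simplified host-side sketch of that schedule is below; the buffers and the GEMM are stubbed out and it only shows the ordering of the steps.

// Simplified host-side sketch of the LDS double-buffer schedule (ordering only).
#include <cstdio>

void load(int step) { printf("  load step %d into registers\n", step); }
void store(int step, int buf) { printf("  store step %d into LDS buffer %d\n", step, buf); }
void gemm(int step, int buf) { printf("  GEMM on step %d from LDS buffer %d\n", step, buf); }

int main()
{
    const int num_steps = 6; // e.g. C / CPerBlock; assumed even, mirroring C % (2 * CPerBlock) == 0

    // preload step 0 into buffer 0
    load(0);
    store(0, 0);

    // main body: GEMM on the current buffer overlaps (on real hardware) with
    // fetching the next step into registers and writing it to the other buffer
    for(int step = 0; step + 1 < num_steps; ++step)
    {
        const int now  = step % 2;
        const int next = 1 - now;
        load(step + 1);        // global -> registers
        gemm(step, now);       // consume current LDS buffer
        store(step + 1, next); // registers -> other LDS buffer
    }

    // tail: drain the last buffer
    gemm(num_steps - 1, (num_steps - 1) % 2);
    return 0;
}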
@@ -1,283 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = flatten(N, Hi, Wi)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t BPerThread,
|
||||
index_t KPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t B = N * Hi * Wi;
|
||||
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
// blockwise in copy
|
||||
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxb_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// load data
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 1
|
||||
blockwise_gemm.Run_amd_asm
|
||||
#endif
|
||||
(p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
|
||||
index_t h_data = b_data / (Wi * N);
|
||||
index_t itmp = b_data - h_data * (Wi * N);
|
||||
index_t w_data = itmp / N;
|
||||
index_t n_data = itmp - w_data * N;
|
||||
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
|
||||
k_data, h_data, w_data, n_data)] =
|
||||
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
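The v2 kernel above flattens the input spatial dimensions into B = N * Hi * Wi and reads BPerBlock + BGhostRead contiguous elements per block, with BGhostRead = (Y - 1) * Wi + (X - 1) covering the extra positions touched by the (y, x) filter taps. On write-back each flattened position b is decomposed back into (h, w, n) and positions outside the Ho x Wo output window are skipped. A small host-side sketch of that decomposition and bound check, using hypothetical shapes, is below.

// Host-side sketch of the write-back index math: b in N*Hi*Wi -> (h, w, n),
// keeping only positions inside the Ho x Wo output window. Shapes are examples.
#include <cstdio>

int main()
{
    const int N = 4, Hi = 6, Wi = 6; // input layout flattened into B
    const int Ho = 4, Wo = 4;        // output window (e.g. 3x3 filter, no padding)

    const int B = N * Hi * Wi;
    int written = 0;

    for(int b = 0; b < B; ++b)
    {
        const int h    = b / (Wi * N);
        const int itmp = b - h * (Wi * N);
        const int w    = itmp / N;
        const int n    = itmp - w * N;

        // only (h, w) inside the output window produce a valid output element
        if(n < N && h < Ho && w < Wo)
            ++written;
    }

    printf("valid output positions: %d (expected %d)\n", written, N * Ho * Wo);
    return 0;
}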
@@ -1,408 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = flatten(N, Hi, Wi)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t BPerThread,
|
||||
index_t KPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t B = N * Hi * Wi;
|
||||
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
// blockwise in copy
|
||||
// format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxb_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_double);
|
||||
}
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
// load next data
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_amd_asm
|
||||
#endif
|
||||
(p_wei_block_now +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// even
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
|
||||
p_wei_register_buffer);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_amd_asm
|
||||
#endif
|
||||
(p_wei_block_double +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_double + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd
|
||||
__syncthreads();
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_amd_asm
|
||||
#endif
|
||||
(p_wei_block_double + wei_block_space +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_double + in_block_space + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
if(Y == 1 && X == 1)
|
||||
{ // pure 1x1 conv (no padding, unit stride)
|
||||
constexpr index_t K2_ = GemmMPerThreadSubC;
|
||||
constexpr index_t K1_ = KPerBlock / KPerThread;
|
||||
constexpr index_t B2_ = GemmNPerThreadSubC;
|
||||
constexpr index_t B1_ = BPerBlock / BPerThread;
|
||||
|
||||
constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1_ * K2_), K1_, K2_, B / (B1_ * B2_), B1_, B2_>{});
|
||||
|
||||
constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, BPerBlock / (B1_ * B2_), 1, B2_>{});
|
||||
|
||||
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
|
||||
|
||||
threadwise_6d_tensor_copy(out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global +
|
||||
out_kb_global_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
|
||||
index_t h_data = b_data / (Wi * N);
|
||||
index_t itmp = b_data - h_data * (Wi * N);
|
||||
index_t w_data = itmp / N;
|
||||
index_t n_data = itmp - w_data * N;
|
||||
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
|
||||
k_data, h_data, w_data, n_data)] =
|
||||
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
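The kernel above, and the lds_double_buffer variants later in this diff, all follow the same LDS double-buffering scheme: preload one chunk into LDS, then on each iteration prefetch the next chunk into a register buffer, run the blockwise GEMM on the chunk currently resident in LDS, and store the prefetched data into the other half of the buffer, with a separate tail step for the final chunks. A minimal host-side sketch of that ping-pong pattern follows; the load_chunk/compute_chunk helpers and buffer sizes are illustrative stand-ins for the blockwise copy and blockwise GEMM stages, not library code.

// Host-side sketch of the LDS double-buffering pattern (illustration only).
#include <array>
#include <cstdio>
#include <vector>

constexpr int ChunkSize = 4;

void load_chunk(const std::vector<float>& src, int begin, std::array<float, ChunkSize>& dst)
{
    for(int i = 0; i < ChunkSize; ++i)
        dst[i] = src[begin + i];
}

float compute_chunk(const std::array<float, ChunkSize>& buf)
{
    float acc = 0;
    for(float v : buf)
        acc += v; // stands in for the blockwise GEMM on the "current" buffer
    return acc;
}

int main()
{
    std::vector<float> global(32, 1.0f);
    std::array<float, ChunkSize> lds[2]; // the two halves of the double buffer

    load_chunk(global, 0, lds[0]); // preload first chunk into "LDS"

    float result = 0;
    int next     = ChunkSize;
    for(int i = 0; next < int(global.size()); ++i, next += ChunkSize)
    {
        std::array<float, ChunkSize> prefetch; // register buffer
        load_chunk(global, next, prefetch);    // load next chunk from "global memory"
        result += compute_chunk(lds[i % 2]);   // compute on the current half
        lds[(i + 1) % 2] = prefetch;           // store next chunk into the other half
    }
    result += compute_chunk(lds[(global.size() / ChunkSize - 1) % 2]); // tail
    std::printf("sum = %f\n", result);
    return 0;
}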
@@ -1,376 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_C_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_C_K,
|
||||
class WeiBlockCopyClusterLengths_C_K,
|
||||
index_t WeiBlockCopyDataPerAccess_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
|
||||
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
|
||||
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
|
||||
InBlockCopySubLengths_C_N1_B_N2,
|
||||
InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
|
||||
Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
|
||||
Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[CPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_amd_asm(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
#if 0
|
||||
// do work
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
for(index_t
|
||||
c_block_data_on_global = 0;
|
||||
c_block_data_on_global < C;
|
||||
c_block_data_on_global += CPerBlock,
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
for(index_t c_block_data_on_global = 0; c_block_data_on_global < C;
|
||||
c_block_data_on_global += CPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(
|
||||
I0, Number<CPerBlock>{}, True);
|
||||
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(
|
||||
I0, Number<CPerBlock>{}, True);
|
||||
}
|
||||
|
||||
// reset C
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
|
||||
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
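The v3 kernels treat B = merge(N0, Ho, Wo) as a single merged dimension, and the merged tensor descriptor is what maps a flat B index back to its (n0, ho, wo) components when computing global offsets. A standalone sketch of that decode, assuming a plain row-major merge order purely for illustration, is:

// Illustration only: decoding a merged index B = merge(N0, Ho, Wo).
#include <cstdio>

int main()
{
    constexpr int N0 = 2, Ho = 3, Wo = 4;
    constexpr int B  = N0 * Ho * Wo;

    for(int b = 0; b < B; ++b)
    {
        const int n0  = b / (Ho * Wo);
        const int tmp = b - n0 * (Ho * Wo);
        const int ho  = tmp / Wo;
        const int wo  = tmp - ho * Wo;

        // recombining the components must reproduce the merged index
        if(b != (n0 * Ho + ho) * Wo + wo)
            std::printf("decode mismatch at b = %d\n", b);
    }
    std::printf("decoded %d merged indices\n", B);
    return 0;
}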
@@ -1,394 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_C_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_C_K,
|
||||
class WeiBlockCopyClusterLengths_C_K,
|
||||
index_t WeiBlockCopyDataPerAccess_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % (2 * CPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
|
||||
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
|
||||
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
|
||||
InBlockCopySubLengths_C_N1_B_N2,
|
||||
InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
|
||||
Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
|
||||
Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[CPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_c_n1_b_n2_block_mem_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_c_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// do work
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global,
|
||||
p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
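Both v3 kernels accumulate with the convention c_mtx += transpose(a_mtx) * b_mtx, i.e. the reduction dimension (C here, E in the v4 kernels) is the row dimension of both LDS matrices while the output stays [KPerBlock, N1 * BPerBlock * N2]. A plain host-side reference of that convention, with made-up sizes, is:

// Illustration only: c += transpose(a) * b with the reduction dim as rows of a and b.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int E = 3, M = 2, N = 4; // E: reduction dim, M ~ KPerBlock, N ~ N1*BPerBlock*N2

    std::vector<float> a(E * M, 1.0f); // a[e][m], row-major [E, M]
    std::vector<float> b(E * N, 2.0f); // b[e][n], row-major [E, N]
    std::vector<float> c(M * N, 0.0f); // c[m][n], row-major [M, N]

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int e = 0; e < E; ++e)
                c[m * N + n] += a[e * M + m] * b[e * N + n]; // c += transpose(a) * b

    std::printf("c[0][0] = %f (expected %f)\n", c[0], float(E) * 1.0f * 2.0f);
    return 0;
}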
@@ -8,53 +8,9 @@
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "convolution_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <ConvolutionDirection>
|
||||
struct make_wei_e_k_global_desc_v4r1;
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};
|
||||
|
||||
constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
|
||||
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
@@ -66,18 +22,17 @@ template <index_t GridSize,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
ConvolutionDirection ConvDirection,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmNRepeat,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMPerThread,
|
||||
index_t GemmNPerThread,
|
||||
index_t GemmKPerThread,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_N1_B_N2,
|
||||
@@ -107,19 +62,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
static_assert(ConvDirection == ConvolutionDirection::Forward ||
|
||||
ConvDirection == ConvolutionDirection::BackwardWeight,
|
||||
"wrong! this kernel only support convolution forward and backward-weight");
|
||||
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N2 = GemmNPerThread;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
(N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
@@ -240,7 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
// It is constructed differently, depending on whether forward or backward weight
|
||||
// convolution
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
|
||||
make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// block tensor in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
@@ -290,30 +243,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
in_e_n1_b_n2_block_desc.GetStride(I0));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k1_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
@@ -432,13 +384,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / K1;
|
||||
|
||||
// define output tensor descriptor for threadwise copy
|
||||
// thread output tensor, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});
|
||||
|
||||
// global output tensor
|
||||
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
#include "convolution_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <ConvolutionDirection>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated;
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return make_ConstantMergedTensorDescriptor(
|
||||
WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
|
||||
}
|
||||
};
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class AccDataType,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
ConvolutionDirection ConvDirection,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmNRepeat,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
constexpr auto global_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Global>{};
|
||||
|
||||
static_assert(ConvDirection == ConvolutionDirection::Forward ||
|
||||
ConvDirection == ConvolutionDirection::BackwardWeight,
|
||||
"wrong! this kernel only support convolution forward and backward-weight");
|
||||
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert(
|
||||
(Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
|
||||
"wrong! alignment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
|
||||
BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
// It is constructed differently, depending on whether forward or backward weight
|
||||
// convolution
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k1_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
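// note: every vectorized access that touches these LDS buffers (the N2-wide input
// writes, the K-wide weight writes, and the A/B reads of the blockwise GEMM) must see
// a compatible alignment, so the buffer sizes below are padded to their LCM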
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global, p_in_block_double, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
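// the main loop above stops while at least 2 * EPerBlock of the E dimension is still
// unconsumed, so the tail has either one or two EPerBlock slices left; which case
// applies is known at compile time from E % (2 * EPerBlock)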
|
||||
|
||||
if(has_two_iteration_left) // 2 iterations left
|
||||
{
|
||||
// even iteration
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
else // 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// output merged global tensor descriptor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
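The kernels removed in this commit all reduce convolution to the same matrix product: the weight is read as an [E, K] matrix, the input as an [E, B] matrix with E = C * Y * X and B built from N, Ho, Wo, and each workgroup accumulates a KPerBlock x BPerBlock tile of c_mtx += transpose(a_mtx) * b_mtx, one EPerBlock slice of E at a time. A minimal host-side sketch of that accumulation order (reference_gemm_tn and the plain std::vector storage are hypothetical, for illustration only):

#include <cstddef>
#include <vector>

// c[K, B] += transpose(a[E, K]) * b[E, B], consuming EPerBlock rows of E per step,
// which mirrors how one LDS buffer's worth of data is reduced by the blockwise GEMM
void reference_gemm_tn(const std::vector<float>& a, // E x K, row-major
                       const std::vector<float>& b, // E x B, row-major
                       std::vector<float>& c,       // K x B, row-major
                       std::size_t E, std::size_t K, std::size_t B, std::size_t EPerBlock)
{
    for(std::size_t e0 = 0; e0 < E; e0 += EPerBlock)
        for(std::size_t e = e0; e < e0 + EPerBlock && e < E; ++e)
            for(std::size_t k = 0; k < K; ++k)
                for(std::size_t n = 0; n < B; ++n)
                    c[k * B + n] += a[e * K + k] * b[e * B + n];
}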
|
||||
@@ -1,432 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t Ho1,
|
||||
index_t Ho2,
|
||||
index_t Wo1,
|
||||
index_t Wo2,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
|
||||
class InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_W2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
constexpr index_t B = N1 * Ho1 * Wo1;
|
||||
|
||||
static_assert(N % (N1 * N2) == 0 && Ho % (Ho1 * Ho2) == 0 && Wo % (Wo1 * Wo2) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
constexpr index_t Ho0 = Ho / (Ho1 * Ho2);
|
||||
constexpr index_t Wo0 = Wo / (Wo1 * Wo2);
|
||||
|
||||
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
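// note: the grid is a flat 1D launch; block_work_desc turns the 1D block id into a
// (k, b) tile coordinate, so each workgroup owns one KPerBlock x BPerBlock tile of the
// output GEMM starting at (k_block_data_on_global, b_block_data_on_global)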
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
|
||||
constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
|
||||
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
|
||||
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I2, Number<Wo1>{}, Number<Wo2>{})
|
||||
.Fold(I1, Number<Ho1>{}, Number<Ho2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
constexpr auto in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc =
|
||||
in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
|
||||
Sequence<0, 3, 6, 1, 4, 7, 2, 5, 8>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2], src of blockwise copy
|
||||
constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6, 7, 8>{},
|
||||
Sequence<9>{},
|
||||
Sequence<10>{},
|
||||
Sequence<11>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<EPerBlock, N0, Ho0, Wo0, BPerBlock, N2, Ho2, Wo2>{});
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc),
|
||||
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc),
|
||||
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
|
||||
InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_W2,
|
||||
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
|
||||
0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
|
||||
Number<N0 * Ho0 * Wo0 * N2 * Ho2 * Wo2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc),
|
||||
decltype(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space = math::integer_least_multiple(
|
||||
in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
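// note: the input copy advances through its merged descriptor via the slicing window,
// while the weight source is advanced by a plain pointer bump of EPerBlock rows, since
// the E dimension of wei_e_k_global_desc is a normal (unmerged) dimension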
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N0, Ho0, Wo0, 1, 1, 1, N2, Ho2, Wo2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
|
||||
out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<3, 6, 9, 0, 1, 2, 4, 7, 10, 5, 8, 11>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
|
||||
.Fold(I2, Sequence<Ho1, Ho2>{})
|
||||
.Fold(I1, Sequence<K1, K2>{})
|
||||
.Fold(I0, Sequence<N1, N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<0>{},
|
||||
Sequence<4>{},
|
||||
Sequence<7>{},
|
||||
Sequence<1, 5, 8>{},
|
||||
Sequence<2>{},
|
||||
Sequence<6>{},
|
||||
Sequence<9>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 12, 1>::type{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
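The LDS double-buffer schedule used by all of these kernels is easier to see stripped of the tensor machinery: preload one E-slice, then in each step prefetch the next slice into the idle half of the doubled buffer while the GEMM consumes the current half, and finish with a two-slice tail (E is asserted to be a multiple of 2 * EPerBlock). A compile-and-run sketch of just that schedule, with printf standing in for the copies and the GEMM:

#include <cstdio>

int main()
{
    constexpr int num_slices = 8; // E / EPerBlock; assumed even, as the kernels assert

    std::printf("preload slice 0 into buffer 0\n");

    int s = 0;
    for(; s + 2 < num_slices; ++s) // main body: keep one slice in flight
    {
        std::printf("prefetch slice %d into buffer %d\n", s + 1, (s + 1) % 2);
        std::printf("gemm on slice %d from buffer %d\n", s, s % 2);
    }

    // tail: prefetch the final slice, then drain both buffers
    std::printf("prefetch slice %d into buffer %d\n", s + 1, (s + 1) % 2);
    std::printf("gemm on slice %d from buffer %d\n", s, s % 2);
    std::printf("gemm on slice %d from buffer %d\n", s + 1, (s + 1) % 2);
    return 0;
}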
|
||||
@@ -1,457 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
index_t N0,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t Ho0,
|
||||
index_t Ho1,
|
||||
index_t Ho2,
|
||||
index_t Wo0,
|
||||
index_t Wo1,
|
||||
index_t Wo2,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
|
||||
class InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_W2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
constexpr index_t B = N0 * Ho0 * Wo0;
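// note: v4r3 builds the GEMM N dimension B from the coarse factors N0, Ho0, Wo0
// (v4r2 used N1 * Ho1 * Wo1 instead); the finest factors N2, Ho2, Wo2 are pinned to
// the per-thread sub-tile via the N2 * Ho2 * Wo2 == GemmNPerThreadSubC assertion above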
|
||||
|
||||
static_assert(N == N0 * N1 * N2 && Ho == Ho0 * Ho1 * Ho2 && Wo == Wo0 * Wo1 * Wo2,
|
||||
"wrong!");
|
||||
|
||||
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2]
|
||||
constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc =
|
||||
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
|
||||
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I2, Number<Wo1>{}, Number<Wo2>{})
|
||||
.Fold(I1, Number<Ho1>{}, Number<Ho2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
constexpr auto in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc =
|
||||
in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old(
|
||||
Sequence<1, 4, 7, 0, 3, 6, 2, 5, 8>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, Ho1, Wo1, B, N2, Ho2, Wo2], src of blockwise copy
|
||||
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6, 7, 8>{},
|
||||
Sequence<9>{},
|
||||
Sequence<10>{},
|
||||
Sequence<11>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, Ho1, Wo1, B, N2, Ho2, Wo2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<EPerBlock, N1, Ho1, Wo1, BPerBlock, N2, Ho2, Wo2>{});
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
|
||||
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc),
|
||||
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
|
||||
InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_W2,
|
||||
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id (%d %d), in offset: %d %d, wei offset %d %d\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
blockwise_in_copy.mThreadSrcOffset,
|
||||
blockwise_in_copy.mThreadDstOffset,
|
||||
blockwise_wei_copy.mThreadSrcOffset,
|
||||
blockwise_wei_copy.mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB ==
|
||||
0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor_packed(Number<GemmMRepeat * GemmMPerThreadSubC>{},
|
||||
Number<N1 * Ho1 * Wo1 * N2 * Ho2 * Wo2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space = math::integer_least_multiple(
|
||||
in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid (%d %d), %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
p_wei_register_buffer[0],
|
||||
p_wei_register_buffer[1],
|
||||
p_wei_register_buffer[2],
|
||||
p_wei_register_buffer[3]);
|
||||
}
|
||||
#endif
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, Ho1, Wo1, 1, 1, 1, N2, Ho2, Wo2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc =
|
||||
out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<6, 3, 9, 0, 1, 2, 7, 4, 10, 8, 5, 11>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I3, Sequence<Wo1, Wo2>{})
|
||||
.Fold(I2, Sequence<Ho1, Ho2>{})
|
||||
.Fold(I1, Sequence<K1, K2>{})
|
||||
.Fold(I0, Sequence<N1, N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2);
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<5>{},
|
||||
Sequence<8>{},
|
||||
Sequence<0, 4, 7>{},
|
||||
Sequence<2>{},
|
||||
Sequence<6>{},
|
||||
Sequence<9>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 12, 1>::type{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
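All four kernels size their LDS the same way: each block descriptor's element space is rounded up to the least common multiple of every vectorized access width that touches LDS, so both halves of the double buffer start on an address that satisfies all of those accesses. A small self-contained sketch of that padding rule (the local integer_least_multiple mirrors the math:: helper used above; the widths 4, 1, 4, 4 and the sizes 130, 96 are made-up example values):

#include <cstdio>
#include <numeric>

// round x up to the next multiple of align
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

int main()
{
    // example vector widths: dst-write of the input copy, dst-write of the weight copy,
    // GEMM read of A, GEMM read of B
    const int max_align = std::lcm(std::lcm(4, 1), std::lcm(4, 4));

    const int in_block_space  = integer_least_multiple(130, max_align); // 130 -> 132
    const int wei_block_space = integer_least_multiple(96, max_align);  // already a multiple

    std::printf("max_align=%d in=%d wei=%d\n", max_align, in_block_space, wei_block_space);
    return 0;
}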
|
||||
@@ -75,6 +75,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
#if 0
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
|
||||
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
|
||||
@@ -82,9 +83,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
|
||||
// input tensor
|
||||
@@ -108,14 +110,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_k_b_global_desc =
|
||||
constexpr auto out_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
@@ -127,9 +129,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
InMemoryDataOperation::Set,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
@@ -157,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
1,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
|
||||
@@ -1,404 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// B = merge(N, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_B,
|
||||
class InBlockCopyClusterLengths_E_B,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_B,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N, Ho, Wo]
|
||||
constexpr auto in_n_ho_wo_global_desc =
|
||||
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
|
||||
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, B], src of blockwise copy
|
||||
constexpr auto in_e_b_global_desc =
|
||||
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<3, 4, 5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, B], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_b_block_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, BPerBlock] is in LDS
|
||||
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(
|
||||
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
|
||||
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat =
|
||||
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_b_block_mtx_desc),
|
||||
decltype(c_k0k1_b0b1_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.template Run<Float, AddressSpace::Global>(p_in_global,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.template Run<Float, AddressSpace::Global>(p_wei_global,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::Global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output global descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
// This is a hack, because slicing a merged dimension is not supported yet.
|
||||
// This should be replaced with the logic above, once support for slicing a merged
// dimension becomes available
|
||||
// dst descriptor
|
||||
constexpr auto out_k0_k1_b_global_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<0, 3, 4>{});
|
||||
|
||||
// src descriptor
|
||||
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
using OutThreadCopySliceLengths =
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
|
||||
|
||||
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
|
||||
decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(out_k0_k1_b_global_desc),
|
||||
OutThreadCopySliceLengths,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
2,
|
||||
2,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B>({0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
|
||||
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
|
||||
{
|
||||
threadwise_out_copy
|
||||
.template Run<Float, AddressSpace::Generic, AddressSpace::Global>(p_out_thread,
|
||||
p_out_global);
|
||||
|
||||
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
|
||||
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
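
// --- Illustrative sketch (editor's addition, not part of the original source) ---
// The kernel above overlaps global->LDS traffic with the blockwise GEMM via an LDS
// double buffer: the GEMM reads one half of LDS while the next tile is prefetched into
// registers and then stored to the other half, separated by __syncthreads(). The
// hypothetical helper below shows only that ping-pong structure; the tile size, names
// and the accumulation standing in for the GEMM are made up, and the register staging
// step is omitted for brevity.
template <typename Float, int TileSize>
__device__ Float lds_double_buffer_sketch(const Float* p_global, int num_tiles, Float* p_lds_double)
{
    Float acc = 0;

    // preload tile 0 into the first half of LDS, distributed over the thread block
    for(int j = threadIdx.x; j < TileSize; j += blockDim.x)
        p_lds_double[j] = p_global[j];
    __syncthreads();

    for(int i = 0; i < num_tiles; ++i)
    {
        Float* p_now  = p_lds_double + (i % 2) * TileSize;       // half being consumed
        Float* p_next = p_lds_double + ((i + 1) % 2) * TileSize; // half being filled

        // prefetch the next tile (if any) into the other half
        if(i + 1 < num_tiles)
        {
            for(int j = threadIdx.x; j < TileSize; j += blockDim.x)
                p_next[j] = p_global[(i + 1) * TileSize + j];
        }

        // consume the current tile (stands in for blockwise_gemm.Run)
        for(int j = 0; j < TileSize; ++j)
            acc += p_now[j];

        // make the freshly written half visible before it is consumed next iteration
        __syncthreads();
    }

    return acc;
}
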
@@ -2,7 +2,6 @@
|
||||
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -58,18 +57,6 @@ __host__ __device__ constexpr auto
|
||||
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto
|
||||
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
|
||||
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
|
||||
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
|
||||
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
|
||||
TDesc::GetLengths()[1],
|
||||
TDesc::GetStrides()[0]>{};
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
|
||||
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
|
||||
// it's the tensor whose dimensions are to be merged
|
||||
// OriginalDimMergeSeqs : Sequence<...>...
|
||||
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
struct ConstantMergedTensorDescriptor_deprecated
|
||||
{
|
||||
using Type = ConstantMergedTensorDescriptor_deprecated;
|
||||
|
||||
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
|
||||
|
||||
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
|
||||
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
|
||||
|
||||
__host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
|
||||
{
|
||||
static_assert(nDim <= nOriginalDim, "wrong!");
|
||||
|
||||
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
|
||||
// OriginalTensorDesc::nDim number of dimensions
|
||||
|
||||
// TODO: check OriginalDimMergeSeqs contains all original dimensions
|
||||
|
||||
// TODO: check there is no duplication in OriginalDimMergeSeqs
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
|
||||
{
|
||||
return OriginalTensorDesc{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return std::get<IDim>(mOriginalDimMergeSeqs);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
|
||||
{
|
||||
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
|
||||
|
||||
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
|
||||
{
|
||||
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
|
||||
"wrong! stride of a merged dimension is undefined");
|
||||
|
||||
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
|
||||
|
||||
return OriginalTensorDesc::GetStride(Number<idim_original>{});
|
||||
}
|
||||
|
||||
// this is a hack to return the stride of the last original dimension of a merged dimension
|
||||
// TODO: refactor this once the concept of "dimension" is used
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
|
||||
{
|
||||
constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
|
||||
|
||||
return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetElementSize()
|
||||
{
|
||||
return OriginalTensorDesc::GetElementSize();
|
||||
}
|
||||
|
||||
template <class OriginalDimsPartial>
|
||||
struct lambda_1_GetOriginalMultiIndexFromMultiIndex
|
||||
{
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
|
||||
Array<index_t, nOriginalDim>& original_multi_id;
|
||||
|
||||
__host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
|
||||
Array<index_t, nOriginalDim>& original_multi_id_)
|
||||
: original_multi_id_partial(original_multi_id_partial_),
|
||||
original_multi_id(original_multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
|
||||
|
||||
index_t itmp = original_multi_id_partial[I];
|
||||
|
||||
original_multi_id(idim_original) = itmp;
|
||||
}
|
||||
};
|
||||
|
||||
struct lambda_0_GetOriginalMultiIndexFromMultiIndex
|
||||
{
|
||||
const Array<index_t, nDim>& multi_id;
|
||||
Array<index_t, nOriginalDim>& original_multi_id;
|
||||
|
||||
__host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
|
||||
const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
|
||||
: multi_id(multi_id_), original_multi_id(original_multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim>) const
|
||||
{
|
||||
constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
|
||||
|
||||
// get partial original-multi-id corresponding to this merged dimension
|
||||
const auto original_multi_id_partial =
|
||||
OriginalTensorDesc::Extract(original_dims_partial)
|
||||
.GetMultiIndexFrom1dIndex(multi_id[IDim]);
|
||||
|
||||
static_for<0, original_dims_partial.GetSize(), 1>{}(
|
||||
lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
|
||||
original_multi_id_partial, original_multi_id));
|
||||
}
|
||||
};
|
||||
|
||||
// return type is Array<...>
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
Array<index_t, nOriginalDim> original_multi_id;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
|
||||
|
||||
return original_multi_id;
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
|
||||
|
||||
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
|
||||
|
||||
return packed_desc.GetMultiIndexFrom1dIndex(id);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Pack()
|
||||
{
|
||||
constexpr auto lengths = GetLengths();
|
||||
constexpr auto strides = calculate_tensor_strides_packed(lengths);
|
||||
return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
|
||||
OriginalDimMergeSeqs...)
|
||||
{
|
||||
return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
|
||||
}
|
||||
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
|
||||
{
|
||||
print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
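
// --- Illustrative usage sketch (editor's addition, not from the original source) ---
// A merged descriptor views several dimensions of an existing descriptor as a single
// dimension. For example, an N x K x Ho x Wo output tensor can be addressed as a 2-D
// [K, B] matrix with B = N * Ho * Wo, which is how the gridwise GEMM-based kernels
// above index the output. The sizes below are hypothetical.
namespace ck_merged_desc_example {
using namespace ck;

constexpr auto out_n_k_ho_wo_desc =
    make_ConstantTensorDescriptor_packed(Sequence<128, 256, 14, 14>{});

// dimension 1 (K) stays as-is, dimensions 0, 2, 3 (N, Ho, Wo) are merged into B
constexpr auto out_k_b_desc = make_ConstantMergedTensorDescriptor(
    out_n_k_ho_wo_desc, Sequence<1>{}, Sequence<0, 2, 3>{});

constexpr index_t B = out_k_b_desc.GetLength(Number<1>{}); // 128 * 14 * 14
} // namespace ck_merged_desc_example
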
@@ -1,612 +0,0 @@
|
||||
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
|
||||
#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
|
||||
{
|
||||
constexpr index_t L_back_align =
|
||||
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
|
||||
|
||||
return calculate_tensor_strides_packed_deprecated(
|
||||
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor_deprecated
|
||||
{
|
||||
using Type = ConstantTensorDescriptor_deprecated;
|
||||
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor_deprecated()
|
||||
{
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return Sequence<IDim>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
|
||||
|
||||
struct lambda_AreDimensionsContinuous
|
||||
{
|
||||
bool& is_continuous;
|
||||
|
||||
__host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
|
||||
: is_continuous(is_continuous_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim_>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
constexpr auto IDim_p1 = Number<IDim_ + 1>{};
|
||||
|
||||
is_continuous =
|
||||
is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
|
||||
GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr bool AreDimensionsContinuous()
|
||||
{
|
||||
bool is_continuous = true;
|
||||
|
||||
static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));
|
||||
|
||||
return is_continuous;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsPackedTensor()
|
||||
{
|
||||
return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetElementSize()
|
||||
{
|
||||
return Number<reduce_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetElementSpace()
|
||||
{
|
||||
constexpr index_t element_space_unaligned = reduce_on_sequence(
|
||||
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
|
||||
|
||||
return Number<element_space_unaligned>{};
|
||||
}
|
||||
|
||||
// emulate constexpr lambda
|
||||
template <index_t NSize>
|
||||
struct lambda_GetOffsetFromMultiIndex
|
||||
{
|
||||
Array<index_t, NSize>& multi_id;
|
||||
index_t& offset;
|
||||
|
||||
__host__
|
||||
__device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
|
||||
index_t& offset_)
|
||||
: multi_id(multi_id_), offset(offset_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class X>
|
||||
__host__ __device__ constexpr void operator()(X IDim) const
|
||||
{
|
||||
offset += multi_id[IDim] * Type::GetStride(IDim);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
|
||||
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
constexpr auto multi_id = Sequence<Is...>{};
|
||||
|
||||
return Number<reduce_on_sequence(
|
||||
multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
|
||||
}
|
||||
|
||||
// emulate constexpr lambda
|
||||
template <class PackedStrides>
|
||||
struct lambda_GetMultiIndexFrom1dIndex
|
||||
{
|
||||
index_t& id;
|
||||
Array<index_t, nDim>& multi_id;
|
||||
|
||||
__host__
|
||||
__device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
|
||||
Array<index_t, nDim>& multi_id_)
|
||||
: id(id_), multi_id(multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim_>
|
||||
__host__ __device__ constexpr void operator()(IDim_) const
|
||||
{
|
||||
constexpr auto IDim = IDim_{};
|
||||
constexpr index_t stride = PackedStrides::Get(IDim);
|
||||
multi_id(IDim) = id / stride;
|
||||
id -= multi_id[IDim] * stride;
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths()));
|
||||
|
||||
// calculate the index in each dimension, in order from the highest (slowest-varying) dimension
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
|
||||
multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
// This function doesn't do carry check on the highest dimension for positive stepping (or
// borrow check on the highest dimension for negative stepping), for performance reasons. It is
// the user's responsibility to make sure the result "new_multi_id" is not out-of-bound on the
// highest dimension for positive stepping (or on the lowest dimension for negative stepping)
|
||||
template <bool PositiveDirection>
|
||||
__host__ __device__ static Array<index_t, nDim>
|
||||
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
|
||||
index_t step_size_of_1d_index,
|
||||
integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
Array<index_t, nDim> new_multi_id;
|
||||
|
||||
const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
new_multi_id = old_multi_id + step_sizes;
|
||||
|
||||
bool carry = false;
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
|
||||
constexpr index_t idim = nDim - 1 - IDimReverse;
|
||||
constexpr auto IDim = Number<idim>{};
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++new_multi_id(idim);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(new_multi_id[idim] >= GetLength(IDim))
|
||||
{
|
||||
new_multi_id(idim) -= GetLength(IDim);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
}).Else([&](auto) {
|
||||
// shift up multi-id to avoid unsigned integer underflow during intermediate
|
||||
// calculations. After the shift, should have new_multi_id[...] >= 1
|
||||
new_multi_id = old_multi_id + (GetLengths() - step_sizes);
|
||||
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
|
||||
constexpr index_t idim = nDim - 1 - IDimReverse;
|
||||
constexpr auto IDim = Number<idim>{};
|
||||
|
||||
if(borrow)
|
||||
{
|
||||
--new_multi_id(idim);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(new_multi_id[idim] < GetLength(IDim))
|
||||
{
|
||||
new_multi_id(idim) += GetLength(IDim);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
|
||||
// shift back down multi-id
|
||||
// here, should have new_multi_id[...] >= GetLengths()
|
||||
new_multi_id = new_multi_id - GetLengths();
|
||||
});
|
||||
|
||||
return new_multi_id;
|
||||
}
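
// Worked example (editor's addition, hypothetical sizes): for lengths {4, 3, 5} the packed
// strides are {15, 5, 1}, so a positive 1d step of 6 decomposes into step_sizes {0, 1, 1}.
// Starting from old_multi_id {1, 2, 4} (1d offset 29), the add gives {1, 3, 5}; the carry
// pass then wraps the lowest dimension (5 -> 0, carry), wraps the middle one (3 + 1 = 4 -> 1,
// carry) and increments the highest one, yielding {2, 1, 0}, i.e. 1d offset 35 = 29 + 6.
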
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
|
||||
{
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
|
||||
"wrong! too many number of dimensions to be extracted");
|
||||
|
||||
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
|
||||
using extract_strides = decltype(Strides::Extract(extract_dims...));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<extract_lengths, extract_strides>{};
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
|
||||
{
|
||||
return Extract(Number<IDims>{}...);
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
using leaf_tensor = ConstantTensorDescriptor_deprecated<Ts...>;
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
|
||||
decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
|
||||
}
|
||||
|
||||
template <index_t IDimVector, index_t DataPerVector>
|
||||
struct lambda_IsVectorizationAllowed
|
||||
{
|
||||
bool& is_allowed;
|
||||
|
||||
__host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_)
|
||||
: is_allowed(is_allowed_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim_>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0)
|
||||
{
|
||||
is_allowed = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t IDimVector, index_t DataPerVector>
|
||||
__host__ __device__ static constexpr bool IsVectorizationAllowed(Number<IDimVector>,
|
||||
Number<DataPerVector>)
|
||||
{
|
||||
bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) &&
|
||||
Lengths{}[IDimVector] % DataPerVector == 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
lambda_IsVectorizationAllowed<IDimVector, DataPerVector>{is_allowed});
|
||||
|
||||
return is_allowed;
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t DataPerVector>
|
||||
__host__ __device__ static constexpr auto Vectorize(Number<IDim>, Number<DataPerVector>)
|
||||
{
|
||||
constexpr auto idim = Number<IDim>{};
|
||||
constexpr auto data_per_vector = Number<DataPerVector>{};
|
||||
|
||||
static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");
|
||||
|
||||
using vectorized_lengths =
|
||||
decltype(Lengths::Modify(Number<IDim>{}, Number<Lengths{}[IDim] / DataPerVector>{}));
|
||||
using vectorized_strides =
|
||||
decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<vectorized_lengths, vectorized_strides>{};
|
||||
}
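
// Worked example (editor's addition, hypothetical sizes): for a packed descriptor with
// lengths {8, 16} and strides {16, 1}, Vectorize(Number<1>{}, Number<4>{}) is allowed
// (stride 1 on the vectorized dimension, its length divisible by 4, and the other stride
// 16 divisible by 4) and yields lengths {8, 4} with strides {4, 1}: the same memory
// reinterpreted as vectors of 4 elements.
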
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<slice_lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr auto Slice(Sequence<Is...> slice_lengths)
|
||||
{
|
||||
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<decltype(slice_lengths), Strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLength, index_t SliceStride>
|
||||
__host__ __device__ static constexpr auto
|
||||
StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
|
||||
{
|
||||
constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;
|
||||
|
||||
using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
|
||||
using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<new_lengths, new_strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
|
||||
{
|
||||
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
|
||||
|
||||
constexpr index_t fold_intervals_product =
|
||||
reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto unfold_length = GetLength(Number<IDim>{});
|
||||
constexpr auto unfold_stride = GetStride(Number<IDim>{});
|
||||
|
||||
// length of the dimension to be folded needs to be divisible by fold_intervals_product,
// otherwise folding is invalid
|
||||
static_assert(unfold_length % fold_intervals_product == 0,
|
||||
"wrong! length on the dimension to be folded cannot be evenly divided!");
|
||||
|
||||
// folded lengths
|
||||
constexpr auto fold_lengths =
|
||||
Sequence<unfold_length / fold_intervals_product>{}.PushBack(fold_intervals);
|
||||
|
||||
// folded strides
|
||||
constexpr auto fold_strides =
|
||||
Number<unfold_stride>{} *
|
||||
reverse_inclusive_scan_sequence(
|
||||
fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{};
|
||||
constexpr auto right =
|
||||
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::type{};
|
||||
|
||||
constexpr auto new_lengths =
|
||||
GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right));
|
||||
constexpr auto new_strides =
|
||||
GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
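
// Worked example (editor's addition, hypothetical sizes): for a packed descriptor with
// lengths {128, 8} and strides {8, 1}, Fold(Number<0>{}, Number<4>{}) splits dimension 0
// into {32, 4}, giving lengths {32, 4, 8} and strides {32, 8, 1}. The offset of index
// (k0, k1, x) is 32*k0 + 8*k1 + x = 8*(4*k0 + k1) + x, so it addresses the same element
// as (4*k0 + k1, x) did before folding. The Fold(I1, Number<K1>{}) call in the gridwise
// kernel above splits K into (K / K1, K1) in the same way.
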
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
|
||||
{
|
||||
return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
|
||||
}
|
||||
|
||||
// this function unfolds dimensions [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
|
||||
{
|
||||
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
|
||||
constexpr auto middle =
|
||||
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
|
||||
constexpr auto right =
|
||||
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::type{};
|
||||
|
||||
// dimensions to be unfolded need to be continuous
|
||||
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
|
||||
|
||||
// unfolded length, stride
|
||||
constexpr index_t unfold_length = reduce_on_sequence(
|
||||
GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
|
||||
|
||||
// new lengths, strides
|
||||
constexpr auto new_lengths = GetLengths()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_length>{})
|
||||
.PushBack(GetLengths().Extract(right));
|
||||
|
||||
constexpr auto new_strides = GetStrides()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_stride>{})
|
||||
.PushBack(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Pack()
|
||||
{
|
||||
using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, packed_strides>{};
|
||||
}
|
||||
|
||||
template <class MapNew2Old>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
|
||||
{
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
}
|
||||
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
|
||||
{
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(
|
||||
const char* s, ConstantTensorDescriptor_deprecated<Sequence<Lengths...>, Sequence<Strides...>>)
|
||||
{
|
||||
constexpr index_t ndim = sizeof...(Lengths);
|
||||
|
||||
static_assert(ndim > 0 && ndim <= 12, "wrong!");
|
||||
|
||||
static_if<ndim == 1>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 2>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 3>{}([&](auto) {
|
||||
printf(
|
||||
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 4>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 5>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 6>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 7>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
|
||||
"%u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
|
||||
"%u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
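
// --- Illustrative usage sketch (editor's addition, not from the original source) ---
// make_ConstantTensorDescriptor_packed derives row-major strides from the lengths, and
// GetOffsetFromMultiIndex is the usual dot product of index and stride. The sizes below
// are hypothetical.
namespace ck_packed_desc_example {
using namespace ck;

// lengths {2, 3, 4} -> packed strides {12, 4, 1}
constexpr auto desc = make_ConstantTensorDescriptor_packed(Sequence<2, 3, 4>{});

// offset of index (1, 2, 3) is 1*12 + 2*4 + 3*1 = 23
constexpr index_t offset = desc.GetOffsetFromMultiIndex(Sequence<1, 2, 3>{});
} // namespace ck_packed_desc_example
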
@@ -1,348 +0,0 @@
|
||||
#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP
|
||||
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// TensorDesc is ConstantTensorDescriptor_deprecated
|
||||
template <class TensorDesc>
|
||||
struct NormalTensorCoordinate_deprecated
|
||||
{
|
||||
using type = NormalTensorCoordinate_deprecated;
|
||||
using tensor_desc_type = TensorDesc;
|
||||
|
||||
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
|
||||
|
||||
__host__
|
||||
__device__ constexpr NormalTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
|
||||
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
|
||||
{
|
||||
}
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... xs)
|
||||
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence<Xs...>)
|
||||
: NormalTensorCoordinate_deprecated(Array<index_t, nDim>{Xs...})
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
|
||||
|
||||
// T is Array or Sequence
|
||||
template <class T>
|
||||
__host__ __device__ type operator+=(T step_sizes)
|
||||
{
|
||||
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
|
||||
|
||||
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ type operator-=(T step_sizes)
|
||||
{
|
||||
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
|
||||
|
||||
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr type operator+(T step_sizes) const
|
||||
{
|
||||
type coord = *this;
|
||||
coord += step_sizes;
|
||||
return coord;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr type operator-(T step_sizes) const
|
||||
{
|
||||
type coord = *this;
|
||||
coord -= step_sizes;
|
||||
return coord;
|
||||
}
|
||||
|
||||
// reposition point of origin, and return compensated offset.
|
||||
// This is a hack to reduce index calculation during looping over
|
||||
// a tensor whose origin is this TensorCoordinate. It does so by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordinate, so the caller can add the offset into the run-time pointer of
// the data, so only 1 run-time variable (the updated pointer) is needed, instead
// of 2 run-time variables (the old pointer and this offset)
// TODO: after introducing the concept of "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer, instead of both the
// offset and the pointer. This also brings the additional benefit that we don't need to
// worry that the offset might underflow (because the offset is an unsigned integer) when updating it.
|
||||
__host__ __device__ constexpr index_t RepositionOrigin()
|
||||
{
|
||||
index_t offset_diff = mOffset;
|
||||
mOffset = 0;
|
||||
return offset_diff;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t mOffset;
|
||||
};
|
||||
|
||||
// TensorDesc is ConstantMergedTensorDescriptor_deprecated
|
||||
template <class TensorDesc>
|
||||
struct MergedTensorCoordinate_deprecated
|
||||
{
|
||||
using type = MergedTensorCoordinate_deprecated;
|
||||
using tensor_desc_type = TensorDesc;
|
||||
|
||||
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
|
||||
static constexpr index_t nOriginalDim =
|
||||
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
|
||||
__host__
|
||||
__device__ constexpr MergedTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
|
||||
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
|
||||
{
|
||||
// partial offset on each dimension
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
constexpr auto partial_original_dims =
|
||||
tensor_desc_type::GetContainedOriginalDimensions(idim);
|
||||
|
||||
constexpr auto partial_original_desc =
|
||||
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
|
||||
|
||||
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mOriginalIndex, partial_original_dims));
|
||||
});
|
||||
|
||||
// complete offset
|
||||
mOffset =
|
||||
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
}
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
|
||||
: MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
|
||||
|
||||
template <class IDim, class T, bool PositiveDirection>
|
||||
__host__ __device__ void
|
||||
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
constexpr auto idim = idim_;
|
||||
|
||||
// if step_size is known at compile time
|
||||
static_if<is_static<T>::value>{}(
|
||||
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
|
||||
|
||||
// update original index
|
||||
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
|
||||
constexpr auto partial_original_dims =
|
||||
tensor_desc_type::GetContainedOriginalDimensions(idim);
|
||||
|
||||
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
|
||||
|
||||
constexpr auto partial_original_desc =
|
||||
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
|
||||
|
||||
const auto partial_original_step_sizes =
|
||||
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
|
||||
|
||||
// update partial original multi-id
|
||||
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
partial_original_id += partial_original_step_sizes;
|
||||
|
||||
bool carry = false;
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
|
||||
constexpr index_t i = ndim_partial_original - 1 - IReverse;
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++partial_original_id(i);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
|
||||
{
|
||||
partial_original_id(i) -= partial_original_desc.GetLength(i);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension
|
||||
if(carry)
|
||||
{
|
||||
++partial_original_id(0);
|
||||
}
|
||||
}).Else([&](auto) {
|
||||
// shift up multi-id to avoid unsigned integer underflow during intermediate
|
||||
// calculations. After the shift, should have new_multi_id[...] >= 1
|
||||
partial_original_id +=
|
||||
partial_original_desc.GetLengths() - partial_original_step_sizes;
|
||||
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) {
|
||||
constexpr index_t i = ndim_partial_original - 1 - IReverse;
|
||||
|
||||
if(borrow)
|
||||
{
|
||||
--partial_original_id(i);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(partial_original_id[i] < partial_original_desc.GetLength(i))
|
||||
{
|
||||
partial_original_id(i) += partial_original_desc.GetLength(i);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
|
||||
// highest dimension
|
||||
if(borrow)
|
||||
{
|
||||
--partial_original_id(0);
|
||||
}
|
||||
|
||||
// shift back down multi-id
|
||||
// here, should have new_multi_id[...] >= GetLengths()
|
||||
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
|
||||
});
|
||||
|
||||
// update "mOriginalIndex"
|
||||
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
|
||||
constexpr auto idim_original = partial_original_dims[I];
|
||||
|
||||
mOriginalIndex(idim_original) = partial_original_id[I];
|
||||
});
|
||||
|
||||
// calculate new partial offset on this merged dimension
|
||||
const index_t old_partial_offset = mPartialOffsets[idim];
|
||||
|
||||
mPartialOffsets(idim) =
|
||||
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
|
||||
|
||||
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
|
||||
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
|
||||
}).Else([&](auto fwd) {
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
|
||||
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
|
||||
});
|
||||
}
|
||||
|
||||
// T is Array or Sequence
|
||||
template <class T>
|
||||
__host__ __device__ type operator+=(T step_sizes)
|
||||
{
|
||||
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
// compiler should remove dead code path, because step_sizes is known at
|
||||
// compile time
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
|
||||
}
|
||||
});
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ type operator-=(T step_sizes)
|
||||
{
|
||||
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
// compiler should remove dead code path, because step_sizes is known at
|
||||
// compile time
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
|
||||
}
|
||||
});
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr type operator+(T step_sizes) const
|
||||
{
|
||||
type coord = *this;
|
||||
coord += step_sizes;
|
||||
return coord;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr type operator-(T step_sizes) const
|
||||
{
|
||||
type coord = *this;
|
||||
coord -= step_sizes;
|
||||
return coord;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t RepositionOrigin() { return 0; }
|
||||
|
||||
private:
|
||||
// Allocate register memory for all merged dimensions and normal dimensions.
// However, only those merged dimensions whose index will be involved in arithmetic
// after the construction of this TensorCoordinate (e.g. when the user moves a slicing
// window on the merged dimension) will use this register memory.
// Let's hope the compiler will optimize away the register memory allocated for normal
// dimensions, and for those merged dimensions that would never be involved in index
// arithmetic after construction of the TensorCoordinate.
// TODO: refactor TensorCoordinate after introducing the concept of "dimensions",
// and simplify the implementation of ConstantMergedTensorDescriptor_deprecated, so we don't need to
// count on the compiler to optimize away that register memory for us
|
||||
Array<index_t, nOriginalDim> mOriginalIndex;
|
||||
Array<index_t, nDim> mPartialOffsets;
|
||||
|
||||
// complete offset
|
||||
index_t mOffset;
|
||||
};
|
||||
|
||||
template <class TensorDesc>
|
||||
struct TensorCoordinate_deprecated
|
||||
{
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate_deprecated<
|
||||
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
|
||||
}
|
||||
|
||||
public:
|
||||
using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
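
// --- Illustrative usage sketch (editor's addition, not from the original source) ---
// TensorCoordinate_deprecated selects the normal or merged coordinate type from the
// descriptor; a coordinate caches its flattened offset so that moving a slice window
// only needs cheap incremental updates. The descriptor and steps below are hypothetical.
namespace ck_coord_example {
using namespace ck;

using Desc  = decltype(make_ConstantTensorDescriptor_packed(Sequence<2, 3, 4>{}));
using Coord = typename TensorCoordinate_deprecated<Desc>::type; // a NormalTensorCoordinate_deprecated

__host__ __device__ void move_window_sketch()
{
    Coord coord(0, 1, 2);         // offset = 0*12 + 1*4 + 2*1 = 6
    coord += Sequence<0, 1, 0>{}; // step one element along dimension 1 -> offset 10
    (void)coord.GetOffset();
}
} // namespace ck_coord_example
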
@@ -1,16 +0,0 @@
|
||||
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
|
||||
#define CK_TENSOR_COORDINATE_HELPER_HPP
|
||||
|
||||
#include "tensor_coordiante_hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TensorDesc>
|
||||
__host__ __device__ constexpr auto
|
||||
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
|
||||
{
|
||||
return typename TensorCoordinate<TensorDesc>::type(idx);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -18,11 +18,11 @@ template <index_t BlockSize,
|
||||
typename ThreadMatrixC,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t KPerThreadLoop,
|
||||
index_t MLevel0ThreadCluster,
|
||||
index_t NLevel0ThreadCluster,
|
||||
index_t MLevel1ThreadCluster,
|
||||
index_t NLevel1ThreadCluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t ThreadGemmADataPerRead_M,
|
||||
index_t ThreadGemmBDataPerRead_N>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
@@ -15,6 +15,8 @@ namespace ck {
|
||||
// The dimension access order can be different for src and dst.
|
||||
// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
// Will do valid mapping check on dst data: No write if dst data has an invalid mapping
|
||||
// BlockSize can be equal or larger than ThreadCluster size, which means some threads may not do
|
||||
// threadwise copy
|
||||
template <index_t BlockSize,
|
||||
typename BlockSrcDesc,
|
||||
typename BlockDstDesc,
|
||||
@@ -31,7 +33,9 @@ template <index_t BlockSize,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::Generic,
|
||||
AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::Generic,
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
|
||||
index_t SrcDataStride = 1,
|
||||
index_t DstDataStride = 1>
|
||||
struct BlockwiseGenericTensorSliceCopy_v4
|
||||
{
|
||||
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
|
||||
@@ -52,23 +56,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
// map threads to cluster
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_id =
|
||||
mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
|
||||
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
|
||||
|
||||
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetThreadBufferSize()
|
||||
@@ -83,14 +87,18 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,14 +109,19 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
|
||||
p_block_dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,10 +136,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
@@ -134,7 +151,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
MoveSrcSliceWindow(const T& step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
@@ -142,7 +163,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
MoveDstSliceWindow(const T& step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
|
||||
if(BlockSize == mThreadClusterDesc.GetElementSize() or
|
||||
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
|
||||
{
|
||||
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -157,7 +182,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
1,
|
||||
SrcAddressSpace,
|
||||
ThreadBufferAddressSpace,
|
||||
InMemoryDataOperation::Set>;
|
||||
InMemoryDataOperation::Set,
|
||||
SrcDataStride,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
|
||||
BlockDstDesc,
|
||||
@@ -168,7 +195,12 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
DstDataPerWrite,
|
||||
ThreadBufferAddressSpace,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>;
|
||||
DstInMemOp,
|
||||
1,
|
||||
DstDataStride>;
|
||||
|
||||
static constexpr auto mThreadClusterDesc =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
|
||||
@@ -1,613 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
// This function assumes each thread is reading and writing a normal (not merged) tensor,
// to simplify index calculations. To satisfy this assumption, the user needs to make sure
// that, on a merged dimension that contains multiple original dimensions, the length of
// the last original dimension is evenly divisible by its sub-length. Also, the
// repeat-length on the merged dimension needs to be 1. These sanity checks are performed
// in the constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated.
|
||||
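// [Editor's sketch, not part of the original diff] Hypothetical lengths illustrating the
// divisibility rule stated in the comment above for a merged dimension such as E = C * Y * X.
constexpr int example_x_length     = 8; // length of the last original dimension in the merged dim
constexpr int example_sub_length_e = 4; // per-thread sub-length on the merged dimension
static_assert(example_x_length % example_sub_length_e == 0,
              "last original dimension must be evenly divisible by the sub-length");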
template <index_t BlockSize,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename SubLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
static constexpr index_t nOriginalDimSrc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
static constexpr index_t nOriginalDimDst =
|
||||
DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
|
||||
// per-thread offset
|
||||
index_t mThreadSrcOffset;
|
||||
index_t mThreadDstOffset;
|
||||
|
||||
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
|
||||
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
|
||||
// updated if slicing-window is moved. However, they will not be used if you always move
|
||||
// the slicing-window along a non-merged dimension. In that case, compiler should be
|
||||
// able to remove these calculation.
|
||||
// TODO: make sure compiler would actually remove them in that case
|
||||
|
||||
// partial offset in each (merged) dimension
|
||||
Array<index_t, nDim> mThreadSrcPartialOffsets;
|
||||
Array<index_t, nDim> mThreadDstPartialOffsets;
|
||||
|
||||
// multi-id of original tensor
|
||||
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
|
||||
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
|
||||
|
||||
__device__
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated(Array<index_t, nDim> src_block_data_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_id_begin)
|
||||
{
|
||||
// check NDim consistency
|
||||
static_assert(
|
||||
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong");
|
||||
|
||||
// check thread arrange order and read/write access order are valid
|
||||
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
|
||||
is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong!");
|
||||
|
||||
// thread cluster
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
// BlockSize
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
|
||||
|
||||
// divide work
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into cluster");
|
||||
});
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
|
||||
|
||||
// additional check for merged dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
// src
|
||||
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
// on a merged dimension that contains multiple original dimensions,
// the length of the last original dimension needs to be evenly divisible by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
|
||||
constexpr auto idim_last_original_src =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(
|
||||
SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
|
||||
SubLengths::Get(IDim) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
// merged dimension should have repeat_lengths = 1
|
||||
static_assert(repeat_lengths[IDim] == 1,
|
||||
"wrong! repeat_lengths shoud be 1 on merged dimension");
|
||||
});
|
||||
|
||||
// dst
|
||||
static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
// on a merged dimension that contains multiple original dimensions,
// the length of the last original dimension needs to be evenly divisible by its
// sub-length,
// so each thread is effectively reading a normal (not merged) tensor
|
||||
constexpr auto idim_last_original_dst =
|
||||
DstDesc::GetContainedOriginalDimensions(IDim).Back();
|
||||
static_assert(
|
||||
DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
|
||||
SubLengths::Get(IDim) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
// merged dimension should have repeat_lengths = 1
|
||||
static_assert(repeat_lengths[IDim] == 1,
|
||||
"wrong! repeat_lengths shoud be 1 on merged dimension");
|
||||
});
|
||||
});
|
||||
|
||||
// calculate mThreadSrcOffset, mThreadDstOffset
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
|
||||
// original multi-id
|
||||
mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
|
||||
src_block_data_id_begin + thread_data_id_begin);
|
||||
|
||||
mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
|
||||
dst_block_data_id_begin + thread_data_id_begin);
|
||||
|
||||
// partial offset on each dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto src_partial_original_dims =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto src_partial_original_desc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
|
||||
|
||||
mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
|
||||
});
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto dst_partial_original_dims =
|
||||
DstDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto dst_partial_original_desc =
|
||||
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
|
||||
|
||||
mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
|
||||
});
|
||||
|
||||
// complete offset
|
||||
mThreadSrcOffset = accumulate_on_array(
|
||||
mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
|
||||
mThreadDstOffset = accumulate_on_array(
|
||||
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetRegisterBufferDescriptor()
|
||||
{
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetThreadBufferSize()
|
||||
{
|
||||
return GetRegisterBufferDescriptor().GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void RunLoadThreadBuffer(const TData* __restrict__ p_src,
|
||||
TData* __restrict__ p_buffer) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * ThreadClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
|
||||
constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
|
||||
#else
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
|
||||
const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
|
||||
|
||||
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
|
||||
|
||||
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
|
||||
|
||||
const index_t buffer_offset =
|
||||
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
|
||||
#endif
|
||||
|
||||
// Position the origin of the per-thread window at the point where the multi-index
// of the SrcDesc (which might be a merged tensor) is all-zero. This threadwise slice copy
// assumes each thread copies a normal (not merged) tensor.
// To satisfy this assumption, the user needs to make sure that, on a merged dimension
// that contains multiple original dimensions, the length of the last original
// dimension is evenly divisible by its sub-length. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in the constructor
// of BlockwiseGenericTensorSliceCopy_v1_deprecated.
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<SrcDesc,
|
||||
decltype(thread_buffer_desc),
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
.Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void RunStoreThreadBuffer(const TData* __restrict__ p_buffer,
|
||||
TData* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * ThreadClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
|
||||
|
||||
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
|
||||
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
|
||||
|
||||
constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
|
||||
#else
|
||||
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
|
||||
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
|
||||
|
||||
const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
|
||||
|
||||
const index_t buffer_offset =
|
||||
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
|
||||
|
||||
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
|
||||
#endif
|
||||
|
||||
// Position the origin of the per-thread window at the point where the multi-index
// of the SrcDesc (which might be a merged tensor) is all-zero. This threadwise slice copy
// assumes each thread copies a normal (not merged) tensor.
// To satisfy this assumption, the user needs to make sure that, on a merged dimension
// that contains multiple original dimensions, the length of the last original
// dimension is evenly divisible by its sub-length. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in the constructor
// of BlockwiseGenericTensorSliceCopy_v1_deprecated.
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<decltype(thread_buffer_desc),
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>(
|
||||
make_zero_array<index_t, nDim>(), make_zero_array<index_t, nDim>())
|
||||
.Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
|
||||
{
|
||||
TData p_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadThreadBuffer(p_src, p_buffer);
|
||||
RunStoreThreadBuffer(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
// When moving the slicing window along a merged dimension, if the strides of the
// original dimensions contained by the merged dimension are not in descending order,
// then there is no guarantee that the new offset will be larger than the old offset
// for movement in the positive direction (and vice versa for movement in the negative direction).
// As a result, there is the possibility that the offset calculation may result in
// unsigned integer underflow (due to the "-" operation). However, this hazard should not
// happen, as long as the user makes sure the slicing window is not moved out of
// the boundary of the tensor being sliced. This function doesn't do a runtime sanity
// check on an out-of-bound slicing window, for performance reasons.
|
||||
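// [Editor's sketch, not part of the original diff] Hypothetical offsets showing why the update
// in MoveSlicingWindowOnSourceTensor below does "+" before "-": with unsigned index arithmetic
// the intermediate value stays non-negative even when the new partial offset is smaller.
constexpr unsigned example_thread_offset      = 100;
constexpr unsigned example_old_partial_offset = 40;
constexpr unsigned example_new_partial_offset = 10;
static_assert((example_thread_offset + example_new_partial_offset) - example_old_partial_offset == 70,
              "adding first keeps the intermediate value from underflowing");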
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
|
||||
// Logic for a merged dimension; it also works for a non-merged dimension, but may
// be unnecessarily complicated for the compiler to remove the calculations that are useless for
// a non-merged dimension.
|
||||
|
||||
// extract partial original dimensions
|
||||
constexpr auto src_partial_original_dims =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto src_partial_original_desc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
|
||||
|
||||
// calculate new partial original multi-id
|
||||
auto old_src_partial_original_id =
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
|
||||
|
||||
auto new_src_partial_original_id =
|
||||
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
|
||||
old_src_partial_original_id, StepSize, direction);
|
||||
|
||||
// update "mThreadSrcOriginalMultiId"
|
||||
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
|
||||
constexpr auto IDimOriginal = src_partial_original_dims[I];
|
||||
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
|
||||
});
|
||||
|
||||
// calculate new partial offset on this merged dimension
|
||||
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
|
||||
|
||||
const index_t new_src_partial_offset =
|
||||
src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);
|
||||
|
||||
// update "mThreadSrcPartialOffsets"
|
||||
mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
|
||||
|
||||
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
|
||||
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
|
||||
}).Else([&](auto) {
|
||||
// Logic for a non-merged dimension. If you are never going to move the slicing window on
// a merged dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets",
// which are being calculated here, will never be used later. In this case, the compiler
// should be able to remove these calculations.
// TODO: make sure the compiler would actually remove them in this case.
|
||||
|
||||
// It is the user's responsibility to make sure the slicing window will not be moved out
// of the boundary of the tensor being sliced. Otherwise, there might be hazards like
// unsigned integer underflow. There is NO runtime sanity check to prevent this hazard.
|
||||
|
||||
constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) {
|
||||
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// This version uses TensorCoordinate
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
template <index_t BlockSize,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename SubLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2_deprecated(
|
||||
const Index& src_block_slice_origin, const Index& dst_block_slice_origin)
|
||||
{
|
||||
static_assert(
|
||||
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetThreadBufferSize()
|
||||
{
|
||||
return ThreadBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename ThreadBufferData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace ThreadBufferAddressSpace>
|
||||
__device__ void
|
||||
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
|
||||
{
|
||||
constexpr auto block_src_address_space =
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
|
||||
mThreadwiseLoad.Run(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename ThreadBufferData>
|
||||
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData,
|
||||
typename BlockDstData,
|
||||
AddressSpace ThreadBufferAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>) const
|
||||
{
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
constexpr auto block_dst_address_space =
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>{};
|
||||
|
||||
mThreadwiseStore.Run(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData, typename BlockDstData>
|
||||
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename BlockDstData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
Run(const BlockSrcData* p_block_src,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
|
||||
{
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename BlockDstData>
|
||||
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
|
||||
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
private:
|
||||
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<SrcDesc,
|
||||
ThreadBufferDesc,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<ThreadBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
"wrong!");

constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);

constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

// c_thread_mtx definition: this is a mess
@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
@@ -0,0 +1,330 @@
|
||||
#ifndef CK_GRIDWISE_TENSOR_CONTRACTION_HPP
|
||||
#define CK_GRIDWISE_TENSOR_CONTRACTION_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockLengths,
|
||||
index_t KPerBlock,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation>
|
||||
struct GridwiseTensorContraction_v1
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() {}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global,
|
||||
Float* __restrict__ p_shared_block) const
|
||||
{
|
||||
/// \todo sanity-check on AGlobalDesc, BGlobalDesc, CGlobalDesc length consistency
/// \todo sanity-check on CBlockLengths
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto a_global_desc = AGlobalDesc{};
|
||||
constexpr auto b_global_desc = BGlobalDesc{};
|
||||
constexpr auto c_global_desc = CGlobalDesc{};
|
||||
|
||||
constexpr auto K = a_global_desc.GetLengths()[0];
|
||||
|
||||
// don't do anything if K == 0
|
||||
if(K == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// lds max alignment
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
ThreadGemmAThreadCopySrcDataPerRead_M,
|
||||
ThreadGemmBThreadCopySrcDataPerRead_N);
|
||||
|
||||
// divide block work by [M, N]
|
||||
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t MBlockWork = M / MPerBlock;
|
||||
constexpr index_t NBlockWork = N / NPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
|
||||
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
AGlobalDesc,
|
||||
decltype(a_block_desc),
|
||||
decltype(a_k_m_block_desc.GetLengths()),
|
||||
ABlockCopyThreadSliceLengths_K_M,
|
||||
ABlockCopyThreadClusterLengths_K_M,
|
||||
ABlockCopyThreadClusterArrangeOrder,
|
||||
ABlockCopySrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockCopySrcVectorReadDim,
|
||||
1,
|
||||
ABlockCopySrcDataPerRead,
|
||||
ABlockCopyDstDataPerWrite_M,
|
||||
AddressSpace::Global,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Lds,
|
||||
InMemoryDataOperation::Set>(
|
||||
{0, m_block_data_on_global}, {0, 0});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(b_k_n_global_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
decltype(b_k_n_block_desc.GetLengths()),
|
||||
BBlockCopyThreadSliceLengths_K_N,
|
||||
BBlockCopyThreadClusterLengths_K_N,
|
||||
BBlockCopyThreadClusterArrangeOrder,
|
||||
BBlockCopySrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
BBlockCopySrcVectorReadDim,
|
||||
1,
|
||||
BBlockCopySrcDataPerRead,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
AddressSpace::Global,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Lds,
|
||||
InMemoryDataOperation::Set>(
|
||||
{0, n_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
|
||||
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
|
||||
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_k_m_block_mtx_desc),
|
||||
decltype(b_k_n_block_mtx_desc),
|
||||
decltype(c_m0m1_n0n1_thread_mtx_desc),
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
MLevel0Cluster,
|
||||
NLevel0Cluster,
|
||||
MLevel1Cluster,
|
||||
NLevel1Cluster,
|
||||
KPerThread,
|
||||
ThreadGemmAThreadCopySrcDataPerRead_M,
|
||||
ThreadGemmBThreadCopySrcDataPerRead_N>{};
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
Float* p_a_block_double = p_shared_block;
|
||||
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
|
||||
|
||||
// register allocation for output
|
||||
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.Run(p_a_global, p_a_block_double);
|
||||
b_blockwise_copy.Run(p_b_global, p_b_block_double);
|
||||
}
|
||||
|
||||
constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
|
||||
constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
|
||||
|
||||
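// [Editor's sketch, not part of the original diff] Hypothetical K and KPerBlock values showing
// the double-buffer schedule below: the main body consumes 2 * KPerBlock per outer iteration,
// and the tail handles the remaining one or two KPerBlock chunks.
constexpr int example_K           = 64;
constexpr int example_KPerBlock   = 8;
constexpr int example_tail_chunks = (example_K % (2 * example_KPerBlock) == 0) ? 2 : 1;
constexpr int example_main_chunks = example_K / example_KPerBlock - example_tail_chunks;
static_assert(example_main_chunks == 6 && example_tail_chunks == 2, "example numbers only");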
// LDS double buffer: main body
|
||||
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
|
||||
k_block_data_begin += 2 * KPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_a_block_now =
|
||||
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
|
||||
Float* p_b_block_now =
|
||||
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
|
||||
|
||||
Float* p_a_block_next =
|
||||
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
|
||||
Float* p_b_block_next =
|
||||
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
|
||||
|
||||
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
|
||||
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
|
||||
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
|
||||
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
|
||||
|
||||
if(has_two_iteration_left) // if there are 2 iterations left
|
||||
{
|
||||
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
|
||||
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
|
||||
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
|
||||
p_a_block_double + a_block_space);
|
||||
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
|
||||
p_b_block_double + b_block_space);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
|
||||
}
|
||||
else // if there is 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t M0 = M / M1;
|
||||
|
||||
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t N0 = N / N1;
|
||||
|
||||
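// [Editor's sketch, not part of the original diff] Hypothetical sizes for the M = M0 * M1
// unmerge used below: a global row index m decomposes into (m / M1, m % M1), which is how the
// threadwise-copy origin {m / M1, m % M1, n / N1, n % N1} further down is formed.
constexpr int example_M  = 128;
constexpr int example_M1 = 32; // MPerThread * MLevel0Cluster * MLevel1Cluster (hypothetical)
constexpr int example_M0 = example_M / example_M1;
constexpr int example_m  = 70; // a hypothetical m_thread_data_on_global
static_assert(example_m / example_M1 < example_M0 && example_m % example_M1 < example_M1,
              "the decomposed index stays inside the [M0, M1] box");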
// define input tensor descriptor for threadwise copy
|
||||
// thread input tensor, src of threadwise copy
|
||||
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
|
||||
|
||||
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
|
||||
c_m_n_global_desc,
|
||||
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// calculate origin of thread input tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t m_thread_data_on_global =
|
||||
m_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t n_thread_data_on_global =
|
||||
n_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
decltype(c_m0_m1_n0_n1_global_desc),
|
||||
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
|
||||
CThreadCopySrcDstAccessOrder,
|
||||
CThreadCopySrcDstVectorReadWriteDim,
|
||||
1,
|
||||
CThreadCopyDstDataPerWrite,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Global,
|
||||
CGlobalMemoryDataOperation>(
|
||||
{0, 0, 0, 0},
|
||||
{m_thread_data_on_global / M1,
|
||||
m_thread_data_on_global % M1,
|
||||
n_thread_data_on_global / N1,
|
||||
n_thread_data_on_global % N1})
|
||||
.Run(p_c_thread, p_c_global);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global) const
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
|
||||
|
||||
__shared__ Float p_shared_block[shared_block_size];
|
||||
|
||||
Run(p_a_global, p_b_global, p_c_global, p_shared_block);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -23,7 +23,9 @@ template <typename SrcDesc,
|
||||
index_t DstDataPerWrite,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::Generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::Generic,
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set>
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
|
||||
index_t SrcDataStride = 1,
|
||||
index_t DstDataStride = 1>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
@@ -116,7 +118,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::Vgpr,
|
||||
InMemoryDataOperation::Set>(
|
||||
InMemoryDataOperation::Set,
|
||||
SrcDataStride,
|
||||
1>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
@@ -148,7 +152,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
DstDataPerWrite,
|
||||
AddressSpace::Vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
DstInMemOp,
|
||||
1,
|
||||
DstDataStride>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,495 +0,0 @@
|
||||
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimensions of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// It is designed for cases where one of src and dst is a register, and
// the other is device memory or LDS.
|
||||
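// [Editor's sketch, not part of the original diff] Hypothetical access widths showing the
// "long vector" staging this copy uses: the per-iteration buffer holds
// lcm(SrcDataPerAccess, DstDataPerAccess) elements, so it tiles exactly into both the
// vectorized src reads and the vectorized dst writes.
constexpr int example_src_data_per_access = 4;
constexpr int example_dst_data_per_access = 2;
constexpr int example_long_vector_size    = 4; // lcm(4, 2)
static_assert(example_long_vector_size % example_src_data_per_access == 0 &&
                  example_long_vector_size % example_dst_data_per_access == 0,
              "the long vector must be a whole number of src and dst accesses");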
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v1r2_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(
|
||||
Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == DimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
|
||||
|
||||
static_assert(
|
||||
SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// check vectorized memory access
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
static_if<!SrcDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetStride(vector_access_dim) == 1 || SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert((fwd(SrcDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
|
||||
SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
|
||||
static_if<!DstDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetStride(vector_access_dim) == 1 || DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert((fwd(DstDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
|
||||
DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated()
|
||||
: ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <class SrcData, class DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
|
||||
[&](auto long_vector_access_id) {
|
||||
|
||||
// data id w.r.t slicing-window
|
||||
auto long_vector_data_begin_id = long_vector_access_id;
|
||||
long_vector_data_begin_id(vector_access_dim) =
|
||||
long_vector_size * long_vector_access_id[vector_access_dim];
|
||||
|
||||
// buffer to hold a long-vector
|
||||
SrcData p_src_long_vector[long_vector_size];
|
||||
DstData p_dst_long_vector[long_vector_size];
|
||||
|
||||
// load data from src to the long-vector buffer
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * src_data_per_access;
|
||||
|
||||
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
|
||||
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));
|
||||
|
||||
const index_t buffer_offset = i * src_data_per_access;
|
||||
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
}
|
||||
|
||||
// type conversion
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
{
|
||||
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
|
||||
}
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * dst_data_per_access;
|
||||
|
||||
const index_t buffer_offset = i * dst_data_per_access;
|
||||
|
||||
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
|
||||
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));
|
||||
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Array<index_t, nDim> mSrcSliceOrigin;
|
||||
Array<index_t, nDim> mDstSliceOrigin;
|
||||
};
|
||||
|
||||
// This version uses TensorCoordinate_deprecated.
// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses a register buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, where src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the registers
// used for the buffer.
|
||||
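// [Editor's sketch, not part of the original diff] Hypothetical slice lengths showing the
// register staging described above: the whole slice is buffered in registers, so the buffer
// holds the product of the slice lengths.
constexpr int example_reg_slice_lengths[2] = {2, 4};
constexpr int example_buffer_elements      = example_reg_slice_lengths[0] * example_reg_slice_lengths[1];
static_assert(example_buffer_elements == 8, "the entire slice is staged through the register buffer");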
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoordinate = typename TensorCoordinate_deprecated<SrcDesc>::type;
|
||||
using DstCoordinate = typename TensorCoordinate_deprecated<DstDesc>::type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(
|
||||
const Index& src_slice_origin, const Index& dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
|
||||
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// check vectorized memory access
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
|
||||
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
|
||||
SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
|
||||
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
|
||||
DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated()
|
||||
: ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <typename TDesc, class Lengths>
|
||||
struct IsolateMergedDimLengths
|
||||
{
|
||||
template <typename IDim>
|
||||
__device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return TDesc::ContainMultipleOriginalDimensions(idim) ? Lengths{}[idim] : 1;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void Run(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
|
||||
SrcData* p_src_buffer = p_src_buffer_;
|
||||
|
||||
// copy data from src into buffer
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths::Modify(
|
||||
src_vector_access_dim,
|
||||
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
|
||||
|
||||
// Offsets w.r.t. merged dimensions need to be calculated at run-time, while offsets
// w.r.t. normal dimensions are known at compile time.
// Below is a hack to isolate merged-dimension ids from normal-dimension ids, so the
// corresponding offsets can be calculated separately at run-time and at compile-time.
// src_merged_dim_access_lengths has the same value as src_access_lengths on src's
// merged dimensions, and has value = 1 on normal dimensions;
// src_normal_dim_access_lengths has the same value as src_access_lengths on src's
// normal dimensions, and has value = 1 on merged dimensions.
|
||||
constexpr auto src_merged_dim_access_lengths = typename sequence_gen<
|
||||
nDim,
|
||||
IsolateMergedDimLengths<SrcDesc, decltype(src_access_lengths)>>::type{};
|
||||
|
||||
constexpr auto src_normal_dim_access_lengths =
|
||||
src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;
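// Worked example (an illustrative sketch, not part of the original code): suppose the
// slice is 3-d, only dimension 1 of SrcDesc is a merged dimension, and
// src_access_lengths == Sequence<4, 8, 2>. Then
//   src_merged_dim_access_lengths == Sequence<1, 8, 1>
//   src_normal_dim_access_lengths == Sequence<4, 1, 2>
// so their element-wise product equals src_access_lengths, and the two nested ford<>
// loops below together enumerate every access id exactly once.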
|
||||
|
||||
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
|
||||
[&](auto src_merged_dim_access_id) {
|
||||
|
||||
auto src_merged_dim_data_id = src_merged_dim_access_id;
|
||||
src_merged_dim_data_id(src_vector_access_dim) =
|
||||
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;
|
||||
|
||||
// The offset w.r.t. the merged dimensions needs to be computed at run-time.
|
||||
const index_t src_merged_offset =
|
||||
(mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
|
||||
|
||||
ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
|
||||
auto src_normal_dim_access_id) {
|
||||
|
||||
auto src_normal_dim_data_id = src_normal_dim_access_id;
|
||||
src_normal_dim_data_id(src_vector_access_dim) =
|
||||
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;
|
||||
|
||||
// offset w.r.t. normal dimension is known at compile-time
|
||||
const index_t src_normal_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
|
||||
|
||||
src_vector_t vector_data;
|
||||
|
||||
// Read vector from src.
// 1. The source-code version can take src from all kinds of memory space.
// 2. The intrinsic version using buffer_load can only take
//    src from global memory.
//
// Comment for loading from global memory:
// When:
//   1) using source code, in order for the compiler to emit an optimal
//      load instruction, or
//   2) using the buffer_load intrinsic, in order for the ISA to be valid,
// the following assumptions need to be satisfied:
//   1. p_src needs to be block-invariant (assumption)
//   2. src_normal_offset must be calculated at compile time (guaranteed by
//      the algorithm)
//   3. src_merged_offset can be a run-time value (no assumption imposed)
|
||||
static_if<SrcAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
vector_data = amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
|
||||
fwd(p_src), src_merged_offset, src_normal_offset);
|
||||
#else
|
||||
vector_data = *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_normal_offset + src_merged_offset]);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// src can be all kinds of memory-space.
|
||||
vector_data = *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_normal_offset + src_merged_offset]);
|
||||
});
|
||||
|
||||
// unpack vector into buffer
|
||||
for(index_t i = 0; i < SrcDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(src_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
|
||||
|
||||
p_src_buffer[buffer_offset] =
|
||||
reinterpret_cast<const SrcData*>(&vector_data)[i];
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// type conversion
|
||||
// TODO: would compiler do a good job reusing register for buffer?
|
||||
DstData p_dst_buffer_[buffer_desc.GetElementSpace()];
|
||||
DstData* p_dst_buffer = p_dst_buffer_;
|
||||
|
||||
ford<SliceLengths>{}([&](auto idx) {
|
||||
p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] =
|
||||
type_convert<DstData>{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]);
|
||||
});
|
||||
|
||||
// copy data from buffer into dst
|
||||
{
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths::Modify(
|
||||
dst_vector_access_dim,
|
||||
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
|
||||
|
||||
constexpr auto dst_merged_dim_access_lengths = typename sequence_gen<
|
||||
nDim,
|
||||
IsolateMergedDimLengths<DstDesc, decltype(dst_access_lengths)>>::type{};
|
||||
|
||||
constexpr auto dst_normal_dim_access_lengths =
|
||||
dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;
|
||||
|
||||
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_merged_dim_access_id) {
|
||||
|
||||
auto dst_merged_dim_data_id = dst_merged_dim_access_id;
|
||||
dst_merged_dim_data_id(dst_vector_access_dim) =
|
||||
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
|
||||
|
||||
// The offset w.r.t. the merged dimensions needs to be computed at run-time.
|
||||
const index_t dst_merged_offset =
|
||||
(mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
|
||||
|
||||
ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
|
||||
auto dst_normal_dim_access_id) {
|
||||
|
||||
auto dst_normal_dim_data_id = dst_normal_dim_access_id;
|
||||
dst_normal_dim_data_id(dst_vector_access_dim) =
|
||||
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
|
||||
|
||||
dst_vector_t vector_data;
|
||||
|
||||
// pack vector from buffer
|
||||
for(index_t i = 0; i < DstDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(dst_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
|
||||
|
||||
reinterpret_cast<DstData*>(&vector_data)[i] = p_dst_buffer[buffer_offset];
|
||||
}
|
||||
|
||||
// offset w.r.t. normal dimension is known at compile-time
|
||||
const index_t dst_normal_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
|
||||
|
||||
// Write vector into dst.
// 1. The source-code version can take dst from all kinds of memory space.
// 2. The intrinsic version using buffer_store can only take
//    dst from global memory.
//
// Comment for storing into global memory:
// When:
//   1) using source code, in order for the compiler to emit an optimal
//      store instruction, or
//   2) using the buffer_store intrinsic, in order for the ISA to be valid,
// the following assumptions need to be satisfied:
//   1. p_dst needs to be block-invariant (assumption)
//   2. dst_normal_offset must be calculated at compile time (guaranteed by
//      the algorithm)
//   3. dst_merged_offset can be a run-time value (no assumption imposed)
|
||||
static_if<DstAddressSpace == AddressSpace::Global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
|
||||
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
|
||||
#else
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// dst can be all kinds of memory-space
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::Generic>{};
|
||||
|
||||
Run(p_src, p_dst, generic_address_space, generic_address_space);
|
||||
}
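// Usage sketch (hypothetical pointer names): the two-argument overload simply forwards
// with AddressSpace::Generic for both sides, e.g.
//   threadwise_copy.Run(p_thread_buffer, p_lds_block);
// is equivalent to calling the four-argument Run with two generic_address_space tags.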
|
||||
|
||||
// T can be Sequence or Array
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mSrcSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
|
||||
}
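// Usage sketch (hypothetical step sizes): advance the source window by one block tile
// along dimension 0, with the direction known at compile time:
//   threadwise_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{},
//                                      integral_constant<bool, true>{});
// EPerBlock here is an assumed caller-side constant, not something defined in this file.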
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mDstSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoordinate mSrcSliceOrigin;
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -8,65 +8,149 @@ namespace ck {
|
||||
// For the 128-bit buffer resource descriptor (held in SGPRs) used by buffer_load and buffer_store instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
template <typename T>
|
||||
union BufferLoadStoreDwordConfig
|
||||
union BufferAddressConfig
|
||||
{
|
||||
int32x4_t data;
|
||||
T* address[2];
|
||||
int32_t range[4];
|
||||
};
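// Fill-in sketch (illustrative only; p_src and byte_offset are assumed caller-side
// values): the union lets the base address and the range/config words be written
// through separate views, while the packed int32x4_t view is what the buffer
// intrinsics consume as rsrc:
//   BufferAddressConfig<float> cfg;
//   cfg.address[0] = const_cast<float*>(p_src); // base address of the block
//   cfg.range[2]   = -1;                        // number of records (unbounded)
//   cfg.range[3]   = 0x00027000;                // descriptor configuration word
//   float x = __llvm_amdgcn_buffer_load_f32(cfg.data, 0, byte_offset, false, false);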
|
||||
|
||||
__device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.f32");
|
||||
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.f32");
|
||||
|
||||
__device__ float2_t __llvm_amdgcn_buffer_loadx2(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v2f32");
|
||||
|
||||
__device__ float4_t __llvm_amdgcn_buffer_loadx4(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex2(float2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_buffer_atomic_add(float vdata,
|
||||
int32x4_t rsrc,
|
||||
__device__ float2_t
|
||||
__llvm_amdgcn_buffer_load_f32x2(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v2f32");
|
||||
|
||||
__device__ float4_t
|
||||
__llvm_amdgcn_buffer_load_f32x4(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v4f32");
|
||||
|
||||
__device__ half_t __llvm_amdgcn_buffer_load_f16(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.f16");
|
||||
|
||||
__device__ half2_t __llvm_amdgcn_buffer_load_f16x2(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v2f16");
|
||||
|
||||
__device__ half4_t __llvm_amdgcn_buffer_load_f16x4(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v4f16");
|
||||
|
||||
__device__ ushort __llvm_amdgcn_buffer_load_bf16(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.bf16");
|
||||
|
||||
__device__ ushort2_t
|
||||
__llvm_amdgcn_buffer_load_bf16x2(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v2bf16");
|
||||
|
||||
__device__ ushort4_t
|
||||
__llvm_amdgcn_buffer_load_bf16x4(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.v4bf16");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f32x2(float2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v2f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f16(half_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.f16");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f16x2(half2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v2f16");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_f16x4(half4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v4f16");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store_bf16(ushort vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.bf16");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_buffer_store_bf16x2(ushort2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v2bf16");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_buffer_store_bf16x4(ushort4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.v4bf16");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_buffer_atomic_add_f32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer_load requires:
// 1) p_src must be in the global memory space, and the load destination must be in VGPRs
// 2) p_src to be a block-invariant pointer.
// It is the user's responsibility to make sure that is true.
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_load(
|
||||
__device__ typename vector_type<T, VectorSize>::MemoryType amd_buffer_load(
|
||||
const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
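// Call-site sketch (hypothetical arguments): loading four consecutive floats for the
// current thread, where src_thread_data_offset may be a run-time value and
// src_const_data_offset is expected to be known at compile time:
//   float4_t v = amd_buffer_load<float, 4>(p_in_global, thread_data_offset, const_data_offset);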
|
||||
|
||||
// buffer_store requires:
|
||||
@@ -74,30 +158,23 @@ __device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_
|
||||
// 2) p_dst to be a block-invariant pointer.
|
||||
// It is the user's responsibility to make sure that is true.
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void
|
||||
amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
__device__ void amd_buffer_store(const T* p_src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void
|
||||
amd_intrinsic_buffer_atomic_add(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
__device__ void amd_buffer_atomic_add(const T* p_src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
|
||||
template <>
|
||||
__device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
__device__ float amd_buffer_load<float, 1>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
float dst;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> src_block_config;
|
||||
BufferAddressConfig<float> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<float*>(p_src_block);
|
||||
@@ -106,33 +183,19 @@ __device__ float amd_intrinsic_buffer_load<float, 1>(const float* p_src_block,
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
dst = __llvm_amdgcn_buffer_load(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
asm volatile(
|
||||
"\n \
|
||||
buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
|
||||
#endif
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t amd_intrinsic_buffer_load<float, 2>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
float2_t dst;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> src_block_config;
|
||||
return __llvm_amdgcn_buffer_load_f32(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t amd_buffer_load<float, 2>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<float*>(p_src_block);
|
||||
@@ -141,33 +204,19 @@ __device__ float2_t amd_intrinsic_buffer_load<float, 2>(const float* p_src_block
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
dst = __llvm_amdgcn_buffer_loadx2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
asm volatile(
|
||||
"\n \
|
||||
buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
|
||||
#endif
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float4_t amd_intrinsic_buffer_load<float, 4>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
float4_t dst;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> src_block_config;
|
||||
return __llvm_amdgcn_buffer_load_f32x2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float4_t amd_buffer_load<float, 4>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<float*>(p_src_block);
|
||||
@@ -176,32 +225,236 @@ __device__ float4_t amd_intrinsic_buffer_load<float, 4>(const float* p_src_block
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
dst = __llvm_amdgcn_buffer_loadx4(
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
return __llvm_amdgcn_buffer_load_f32x4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ half_t amd_buffer_load<half_t, 1>(const half_t* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<half_t> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<half_t*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
|
||||
|
||||
return __llvm_amdgcn_buffer_load_f16(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
asm volatile(
|
||||
"\n \
|
||||
buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
|
||||
return p_src_block[src_thread_data_offset + src_const_data_offset];
|
||||
#endif
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_intrinsic_buffer_store<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
__device__ half2_t amd_buffer_load<half_t, 2>(const half_t* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
BufferAddressConfig<half_t> src_block_config;
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<half_t*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
return __llvm_amdgcn_buffer_load_f16x2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<half2_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ half4_t amd_buffer_load<half_t, 4>(const half_t* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<half_t> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<half_t*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
return __llvm_amdgcn_buffer_load_f16x4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<half4_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ half8_t amd_buffer_load<half_t, 8>(const half_t* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<half_t> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<half_t*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
static_assert(false, "wrong! not supported");
|
||||
#else
|
||||
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<half8_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ ushort amd_buffer_load<ushort, 1>(const ushort* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
|
||||
|
||||
return __llvm_amdgcn_buffer_load_bf16(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
return p_src_block[src_thread_data_offset + src_const_data_offset];
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ ushort2_t amd_buffer_load<ushort, 2>(const ushort* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
return __llvm_amdgcn_buffer_load_bf16x2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
float dst_out_tmp = __llvm_amdgcn_buffer_load_f32(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<ushort2_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ ushort4_t amd_buffer_load<ushort, 4>(const ushort* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
return __llvm_amdgcn_buffer_load_bf16x4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<ushort4_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ ushort8_t amd_buffer_load<ushort, 8>(const ushort* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<ushort*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
static_assert(false, "wrong! not implemented");
|
||||
#else
|
||||
float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return *reinterpret_cast<ushort8_t*>(&dst_out_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<float, 1>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
@@ -210,35 +463,24 @@ __device__ void amd_intrinsic_buffer_store<float, 1>(const float& src,
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_store(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
__llvm_amdgcn_buffer_store_f32(*p_src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<float, 2>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
@@ -247,35 +489,24 @@ __device__ void amd_intrinsic_buffer_store<float, 2>(const float2_t& src,
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_storex2(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx2 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
__llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast<const float2_t*>(p_src),
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<float, 4>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
@@ -284,35 +515,24 @@ __device__ void amd_intrinsic_buffer_store<float, 4>(const float4_t& src,
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_storex4(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx4 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
__llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast<const float4_t*>(p_src),
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<half_t, 1>(const half_t* p_src,
|
||||
half_t* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<half_t> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
@@ -321,13 +541,246 @@ __device__ void amd_intrinsic_buffer_atomic_add<float, 1>(const float& src,
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_atomic_add(
|
||||
src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
|
||||
|
||||
__llvm_amdgcn_buffer_store_f16(*p_src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
static_assert(false, " wrong! not implemented");
|
||||
p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<half_t, 2>(const half_t* p_src,
|
||||
half_t* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<half_t> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
__llvm_amdgcn_buffer_store_f16x2(*reinterpret_cast<const half2_t*>(p_src),
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
const float* p_src_tmp = reinterpret_cast<const float*>(p_src);
|
||||
|
||||
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<half_t, 4>(const half_t* p_src,
|
||||
half_t* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t);
|
||||
|
||||
BufferAddressConfig<half_t> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
__llvm_amdgcn_buffer_store_f16x4(*reinterpret_cast<const half4_t*>(p_src),
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src);
|
||||
|
||||
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<ushort, 1>(const ushort* p_src,
|
||||
ushort* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
|
||||
|
||||
__llvm_amdgcn_buffer_store_bf16(*p_src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<ushort, 2>(const ushort* p_src,
|
||||
ushort* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
__llvm_amdgcn_buffer_store_bf16x2(*p_src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
const float* p_src_tmp = reinterpret_cast<const float*>(p_src);
|
||||
|
||||
__llvm_amdgcn_buffer_store_f32(*p_src_tmp,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_store<ushort, 4>(const ushort* p_src,
|
||||
ushort* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<ushort> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort);
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_231101
|
||||
__llvm_amdgcn_buffer_store_bf16x4(*p_src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
const float2_t* p_src_tmp = reinterpret_cast<const float2_t*>(p_src);
|
||||
|
||||
__llvm_amdgcn_buffer_store_f32x2(*p_src_tmp,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_atomic_add<float, 1>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
BufferAddressConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
__llvm_amdgcn_buffer_atomic_add_f32(
|
||||
*p_src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_atomic_add<float, 2>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
for(index_t i = 0; i < 2; ++i)
|
||||
{
|
||||
amd_buffer_atomic_add<float, 1>(
|
||||
&p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void amd_buffer_atomic_add<float, 4>(const float* p_src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
for(index_t i = 0; i < 4; ++i)
|
||||
{
|
||||
amd_buffer_atomic_add<float, 1>(
|
||||
&p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -16,15 +16,12 @@
|
||||
#include "functional3.hpp"
|
||||
#include "functional4.hpp"
|
||||
#include "in_memory_operation.hpp"
|
||||
#include "synchronization.hpp"
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hpp"
|
||||
#endif
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
|
||||
#if CK_USE_AMD_XDLOPS
|
||||
#include "amd_xdlops.hpp"
|
||||
#endif
|
||||
|
||||
@@ -25,11 +25,7 @@
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
|
||||
#endif
|
||||
|
||||
// only support gfx908
|
||||
// only gfx908 supports native floating-point atomic add
|
||||
#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0
|
||||
#endif
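// Example (hypothetical build setup): the native atomic-add path would be enabled
// explicitly when compiling only for gfx908, e.g.
//   hipcc --amdgpu-target=gfx908 -DCK_USE_AMD_BUFFER_ATOMIC_ADD=1 ...
// All other targets keep the default of 0.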
|
||||
@@ -47,6 +43,11 @@
|
||||
#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
|
||||
#endif
|
||||
|
||||
// block synchronization only does s_waitcnt lgkmcnt(0), not vmcnt(0)
|
||||
#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
|
||||
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
|
||||
@@ -54,8 +55,24 @@
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
|
||||
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
|
||||
#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
|
||||
#endif
|
||||
|
||||
// workarounds: put all workarounds here
|
||||
// workaround for unnecessary VGPR <--> AGPR data movement when using the mfma LLVM intrinsic
|
||||
#ifndef CK_WORKAROUND_SWDEV_229564
|
||||
#define CK_WORKAROUND_SWDEV_229564 1
|
||||
#endif
|
||||
// workaround for buffer load/store fp16/bfp16 intrinsic bug
|
||||
#ifndef CK_WORKAROUND_SWDEV_231101
|
||||
#define CK_WORKAROUND_SWDEV_231101 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#ifndef CK_CONFIG_NVIDIA_HPP
|
||||
#define CK_CONFIG_NVIDIA_HPP
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include "nvToolsExt.h"
|
||||
#include "helper_cuda.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <nvToolsExt.h>
|
||||
|
||||
// index type: unsigned or signed
|
||||
#define CK_UNSIGNED_INDEX_TYPE 0
|
||||
@@ -19,6 +18,7 @@
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
|
||||
#define CK_USE_AMD_XDLOPS 0
|
||||
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
|
||||
#define CK_USE_AMD_XDLOPS_EMULATE 0
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0
|
||||
@@ -32,16 +32,16 @@ namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
{
|
||||
generic,
|
||||
global,
|
||||
lds,
|
||||
vgpr
|
||||
Generic,
|
||||
Global,
|
||||
Lds,
|
||||
Vgpr
|
||||
};
|
||||
|
||||
enum InMemoryDataOperation
|
||||
{
|
||||
none,
|
||||
atomic_add
|
||||
Set,
|
||||
AtomicAdd
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
|
||||
@@ -11,12 +11,15 @@ typedef float float16_t __attribute__((ext_vector_type(16)));
|
||||
typedef float float32_t __attribute__((ext_vector_type(32)));
|
||||
|
||||
// float16
|
||||
typedef _Float16 half_t;
|
||||
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
|
||||
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
|
||||
typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
|
||||
|
||||
// bfloat16
|
||||
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
|
||||
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
|
||||
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
|
||||
|
||||
template <class T, index_t N>
|
||||
struct vector_type
|
||||
@@ -83,37 +86,37 @@ struct vector_type<float, 4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 1>
|
||||
struct vector_type<half_t, 1>
|
||||
{
|
||||
using MemoryType = half;
|
||||
using MemoryType = half_t;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
|
||||
{
|
||||
static_assert(I < 1, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
*(reinterpret_cast<half_t*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 2>
|
||||
struct vector_type<half_t, 2>
|
||||
{
|
||||
using MemoryType = half2_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
half scalar[2];
|
||||
half_t scalar[2];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
|
||||
{
|
||||
static_assert(I < 2, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
*(reinterpret_cast<half_t*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s0, half s1)
|
||||
__host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
@@ -123,24 +126,24 @@ struct vector_type<half, 2>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 4>
|
||||
struct vector_type<half_t, 4>
|
||||
{
|
||||
using MemoryType = half4_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
half scalar[4];
|
||||
half_t scalar[4];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
|
||||
{
|
||||
static_assert(I < 4, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
*(reinterpret_cast<half_t*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
|
||||
__host__ __device__ static MemoryType Pack(half_t s0, half_t s1, half_t s2, half_t s3)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
@@ -151,6 +154,25 @@ struct vector_type<half, 4>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_t, 8>
|
||||
{
|
||||
using MemoryType = half8_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
half_t scalar[8];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
|
||||
{
|
||||
static_assert(I < 8, "wrong");
|
||||
*(reinterpret_cast<half_t*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<ushort, 1>
|
||||
{
|
||||
@@ -220,6 +242,25 @@ struct vector_type<ushort, 4>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<ushort, 8>
|
||||
{
|
||||
using MemoryType = ushort8_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
ushort scalar[8];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
|
||||
{
|
||||
static_assert(I < 8, "wrong");
|
||||
*(reinterpret_cast<ushort*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
@@ -250,12 +291,40 @@ struct inner_product_with_conversion
|
||||
{
|
||||
static constexpr auto convert = type_convert<T>();
|
||||
|
||||
__device__ T operator()(float4_t a, float4_t b) const
|
||||
{
|
||||
const float* p_a_float = reinterpret_cast<const float*>(&a);
|
||||
const float* p_b_float = reinterpret_cast<const float*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 4; ++v)
|
||||
{
|
||||
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(float2_t a, float2_t b) const
|
||||
{
|
||||
const float* p_a_float = reinterpret_cast<const float*>(&a);
|
||||
const float* p_b_float = reinterpret_cast<const float*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 2; ++v)
|
||||
{
|
||||
acc += convert(p_a_float[v]) * convert(p_b_float[v]);
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
|
||||
|
||||
__device__ T operator()(half2_t a, half2_t b) const
|
||||
{
|
||||
const half* p_a_half = reinterpret_cast<const half*>(&a);
|
||||
const half* p_b_half = reinterpret_cast<const half*>(&b);
|
||||
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
|
||||
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 2; ++v)
|
||||
@@ -268,8 +337,8 @@ struct inner_product_with_conversion
|
||||
|
||||
__device__ T operator()(half4_t a, half4_t b) const
|
||||
{
|
||||
const half* p_a_half = reinterpret_cast<const half*>(&a);
|
||||
const half* p_b_half = reinterpret_cast<const half*>(&b);
|
||||
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
|
||||
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 4; ++v)
|
||||
@@ -279,6 +348,19 @@ struct inner_product_with_conversion
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(half8_t a, half8_t b) const
|
||||
{
|
||||
const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
|
||||
const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 8; ++v)
|
||||
{
|
||||
acc += convert(p_a_half[v]) * convert(p_b_half[v]);
|
||||
}
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(ushort2_t a, ushort2_t b) const
|
||||
{
|
||||
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
|
||||
@@ -305,6 +387,19 @@ struct inner_product_with_conversion
|
||||
}
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(ushort8_t a, ushort8_t b) const
|
||||
{
|
||||
const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
|
||||
const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
|
||||
|
||||
T acc = 0;
|
||||
for(index_t v = 0; v < 8; ++v)
|
||||
{
|
||||
acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
|
||||
}
|
||||
return acc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck

@@ -13,8 +13,18 @@ namespace ck {
using float2_t = float2;
using float4_t = float4;

// float16
// float
typedef float float32_t __attribute__((ext_vector_type(32)));

// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));

// fp16
using half_t = half;
using half2_t = half2;
using half4_t = float2;

template <class T, index_t N>
struct vector_type
@@ -81,37 +91,37 @@ struct vector_type<float, 4>
};

template <>
struct vector_type<half, 1>
struct vector_type<half_t, 1>
{
    using MemoryType = half;
    using MemoryType = half_t;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<half*>(&v) + I) = s;
        *(reinterpret_cast<half_t*>(&v) + I) = s;
    }
};

template <>
struct vector_type<half, 2>
struct vector_type<half_t, 2>
{
    using MemoryType = half2_t;

    union DataType
    {
        MemoryType vector;
        half scalar[2];
        half_t scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<half*>(&v) + I) = s;
        *(reinterpret_cast<half_t*>(&v) + I) = s;
    }

    __host__ __device__ static MemoryType Pack(half s0, half s1)
    __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
    {
        DataType data;
        data.scalar[0] = s0;
@@ -140,8 +150,8 @@ struct inner_product_with_conversion

    __device__ T operator()(half2_t a, half2_t b) const
    {
        const half* p_a_half = reinterpret_cast<const half*>(&a);
        const half* p_b_half = reinterpret_cast<const half*>(&b);
        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 2; ++v)
@@ -151,6 +161,19 @@ struct inner_product_with_conversion

        return acc;
    }

    __device__ T operator()(half4_t a, half4_t b) const
    {
        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 4; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }
        return acc;
    }
};

} // namespace ck
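
A hedged sketch of the vector_type<half_t, N> helpers changed above (illustrative, not part of the diff): it relies only on the Pack and SetScalar members visible in the hunk, and assumes half_t is constructible from float; the kernel name is hypothetical.

__global__ void fill_half2(half2_t* p_out)
{
    // pack two fp16 scalars into one half2_t register
    half2_t v = ck::vector_type<half_t, 2>::Pack(half_t(1.0f), half_t(2.0f));

    // overwrite lane 1 in place; the lane index is a compile-time Number<>
    ck::vector_type<half_t, 2>::SetScalar(v, half_t(3.0f), ck::Number<1>{});

    p_out[blockIdx.x * blockDim.x + threadIdx.x] = v;
}

Routing everything through the half_t alias (instead of spelling out half) keeps the fp16 representation behind a single typedef, so it can be swapped without touching the vector utilities.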

@@ -2,91 +2,159 @@
#define CK_IN_MEMORY_OPERATION_AMD_HPP

#include "float_type.hpp"

#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#endif

namespace ck {

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace>
__device__ void set_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
    atomicAdd(p_dst, src);
}

// atomicAdd for float does not support vector type
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
    float* p_dst_float = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 2; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
    float* p_dst_float = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 4; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <typename T, index_t DataPerAccess>
struct SetData
{
    using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

    // This version is only for compatibility, don't use this version if possible
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
    }

#if CK_USE_AMD_BUFFER_ADDRESSING
    // TODO: use static_if::ElseIf, instead of nested static_if
    static_if<SrcAddressSpace == AddressSpace::Global &&
              DstAddressSpace == AddressSpace::Vgpr>{}([&](auto) {
        // buffer_load requires:
        // 1) p_src must be in global memory space, d_dst must be vgpr
        // 2) p_src to be a block-invariant pointer.
        // It is user's responsibility to make sure that is true.
    // buffer_load requires:
    // 1) p_src must be in global memory space, d_dst must be vgpr
    // 2) p_src to be a block-invariant pointer.
    // It is user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Global, AddressSpace::Vgpr>(const T* p_src,
                                                                  index_t src_offset,
                                                                  T* p_dst,
                                                                  index_t dst_offset) const
    {
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
    }).Else([&](auto) {
        static_if<SrcAddressSpace == AddressSpace::Vgpr &&
                  DstAddressSpace == AddressSpace::Global>{}([&](auto) {
            // buffer_store requires:
            // 1) p_src must be in vgpr space, d_dst must be global memory
            // 2) p_dst to be a block-invariant pointer.
            // It is user's responsibility to make sure that is true.
            amd_intrinsic_buffer_store<T, DataPerAccess>(
                *reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
        }).Else([&](auto) {
            *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
        });
    });
#else
    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
#endif
}
            amd_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
    }

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace>
__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
    // buffer_store requires:
    // 1) p_src must be in vgpr space, d_dst must be global memory
    // 2) p_dst to be a block-invariant pointer.
    // It is user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                  index_t src_offset,
                                                                  T* p_dst,
                                                                  index_t dst_offset) const
    {
        amd_buffer_store<T, DataPerAccess>(&(p_src[src_offset]), p_dst, dst_offset, 0);
    }
#endif
};

template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
    using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

    static_if<SrcAddressSpace == AddressSpace::Vgpr &&
              DstAddressSpace == AddressSpace::Global>{}([&](auto) {
#if CK_USE_AMD_BUFFER_ATOMIC_ADD
        amd_intrinsic_buffer_atomic_add<T, DataPerAccess>(
            *reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
#else
        atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                  *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
    // This version is only for compatibility, don't use this version if possible
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
    }

#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_ADD
    // buffer_atomic_add requires:
    // 1) p_src must be in vgpr space, d_dst must be global memory
    // 2) p_dst to be a block-invariant pointer.
    // It is user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                  index_t src_offset,
                                                                  T* p_dst,
                                                                  index_t dst_offset) const
    {
        amd_buffer_atomic_add<T, DataPerAccess>(&(p_src[src_offset]), p_dst, dst_offset, 0);
    }
#endif
    }).Else([&](auto fwd) {
        static_assert(fwd(false), "atomic_add doesn't support this memory space");
    });
}
};

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp>
          InMemoryDataOperation DstInMemOp,
          index_t SrcDataStride = 1,
          index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                  "wrong! InMemoryDataOperation not supported!");

    // TODO: use static_if::ElseIf
    static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
        set_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
            p_src, src_offset, p_dst, dst_offset);
    });
    // keep it simple, don't use static_if here, otherwise compiler will do weird things
    if(SrcDataStride == 1 && DstDataStride == 1)
    {
        // TODO: use static_if::ElseIf
        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
            SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });

        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
            atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });
        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
            AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });
    }
    else
    {
        for(index_t i = 0; i < DataPerAccess; i++)
        {
            // TODO: use static_if::ElseIf
            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
                SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });

            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
                AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });
        }
    }
}

} // namespace ck
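
To make the dispatch above concrete, here is a hedged usage sketch (illustrative, not from the diff); it uses only the transfer_data signature, the AddressSpace values and the InMemoryDataOperation values visible in this hunk, while the function name and offsets are hypothetical.

__device__ void store_fragment(const float* p_thread_buf, float* p_out_global, index_t out_offset)
{
    // contiguous 4-wide store from per-thread registers to global memory
    ck::transfer_data<float,
                      4,
                      ck::AddressSpace::Vgpr,
                      ck::AddressSpace::Global,
                      ck::InMemoryDataOperation::Set>(p_thread_buf, 0, p_out_global, out_offset);

    // same transfer, but accumulating into the destination (e.g. a partial reduction);
    // with CK_USE_AMD_BUFFER_ATOMIC_ADD this maps to amd_buffer_atomic_add, otherwise to atomicAdd
    ck::transfer_data<float,
                      4,
                      ck::AddressSpace::Vgpr,
                      ck::AddressSpace::Global,
                      ck::InMemoryDataOperation::AtomicAdd>(
        p_thread_buf, 0, p_out_global, out_offset);
}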

@@ -3,56 +3,106 @@

namespace ck {

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace>
__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
    atomicAdd(p_dst, src);
}

// atomicAdd for float does not support vector type
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
    float* p_dst_float = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 2; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
    float* p_dst_float = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 4; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <typename T, index_t DataPerAccess>
struct SetData
{
    using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
}
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
    }
};

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace>
__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
    using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

    static_if<SrcAddressSpace == AddressSpace::Vgpr &&
              DstAddressSpace == AddressSpace::Global>{}([&](auto) {
        atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                  *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
    }).Else([&](auto fwd) {
        static_assert(fwd(false), "atomic_add doesn't support this memory space");
    });
}
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
    }
};

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp>
          InMemoryDataOperation DstInMemOp,
          index_t SrcDataStride = 1,
          index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                  "wrong! InMemoryDataOperation not supported!");

    // TODO: use static_if::ElseIf
    static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
        copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
            p_src, src_offset, p_dst, dst_offset);
    });
    // keep it simple, don't use static_if here, otherwise compiler will do weird things
    if(SrcDataStride == 1 && DstDataStride == 1)
    {
        // TODO: use static_if::ElseIf
        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
            SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });

        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
            atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });
        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
            AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });
    }
    else
    {
        for(index_t i = 0; i < DataPerAccess; i++)
        {
            // TODO: use static_if::ElseIf
            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
                SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });

            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
                AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });
        }
    }
}

} // namespace ck
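
The strided branch added to transfer_data above covers accesses whose elements are not contiguous in memory; it falls back to element-by-element SetData<T, 1> / AtomicAddData<T, 1> calls. A hedged sketch of such a call (illustrative only; the wrapper name and strides are hypothetical, and only template parameters visible in the hunk are used):

template <ck::index_t RowStride>
__device__ void copy_strided_column(const float* p_src, float* p_dst, ck::index_t dst_offset)
{
    // SrcDataStride = RowStride, DstDataStride = 1: the element loop in transfer_data
    // reads p_src[0], p_src[RowStride], ... and writes them contiguously to p_dst
    ck::transfer_data<float,
                      4,
                      ck::AddressSpace::Vgpr,
                      ck::AddressSpace::Global,
                      ck::InMemoryDataOperation::Set,
                      RowStride,
                      1>(p_src, 0, p_dst, dst_offset);
}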

@@ -3,6 +3,7 @@

#include "config.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"

namespace ck {

composable_kernel/include/utility/synchronization.amd.hpp.in (new file)
@@ -0,0 +1,25 @@
#ifndef CK_SYNCHRONIZATION_AMD_HPP
#define CK_SYNCHRONIZATION_AMD_HPP

#include "config.hpp"

namespace ck {

__device__ void __llvm_amdgcn_s_barrier() __asm("llvm.amdgcn.s.barrier");

__device__ void block_sync_lds()
{
#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
    asm volatile("\
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
#else
    __llvm_amdgcn_s_barrier();
#endif
}

__device__ void block_sync_lds_vmem() { __llvm_amdgcn_s_barrier(); }

} // namespace ck
#endif

@@ -0,0 +1,13 @@
#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP
#define CK_SYNCHRONIZATION_NVIDIA_HPP

#include "config.hpp"

namespace ck {

__device__ void block_sync_lds() { __syncthreads(); }

__device__ void block_sync_lds_vmem() { __syncthreads(); }

} // namespace ck
#endif
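
A hedged sketch of the producer/consumer step these barriers are meant to guard (illustrative only; the function body and names are not from the diff):

__device__ void lds_swap_step(float* p_lds_buf, const float* p_in_global, index_t i)
{
    // stage one element per thread into LDS (shared memory)
    p_lds_buf[threadIdx.x] = p_in_global[i];

    // make every thread's LDS write visible to the whole workgroup before anyone reads;
    // on AMD this is the s_waitcnt lgkmcnt(0) + s_barrier sequence above, on NVIDIA __syncthreads()
    ck::block_sync_lds();

    float v = p_lds_buf[(threadIdx.x + 1) % blockDim.x];
    (void)v; // a real kernel would consume v here
}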