Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-12 01:10:17 +00:00

remove dead code
@@ -1,401 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
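// Implicit GEMM formulation used below: the reduction dimension is E = C * Y * X,
// the weight tensor is viewed as an [E, K] matrix and the input as an
// [E, N1 * B * N2] matrix with B = N0 * Ho * Wo, so each block accumulates
// c[KPerBlock, N1 * BPerBlock * N2] += transpose(a[EPerBlock, KPerBlock]) * b[EPerBlock, N1 * BPerBlock * N2]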
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmNRepeat,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_N1_B_N2,
|
||||
typename InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
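// dimension order after the Embed below is [C, Y, X, N0, N1, N2, Ho, Wo];
// E = merge(C, Y, X), B = merge(N0, Ho, Wo)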
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// main loop over the reduction dimension E: stage input and weight tiles in LDS,
// run the blockwise GEMM, then advance the source slice windows
|
||||
for(index_t e = 0; e < E; e += EPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
7,
|
||||
7,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#else
|
||||
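// K1 is the span of GEMM M indices covered by one pass of the thread clusters;
// K is folded into (K / K1, K1) below so that a thread's starting row
// k_thread_data_on_global maps to folded indices (k / K1, k % K1)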
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// output merged global tensor descriptor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};

} // namespace ck
#endif
@@ -1,530 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
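// Padded variant of the v4r1 kernel above: the input (Hi, Wi) dimensions are
// first padded with LeftPads/RightPads via a Pad transform, and the descriptors
// are built with transform_tensor_descriptor instead of the Fold/Extract helpers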
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmNRepeat,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_N1_B_N2,
|
||||
typename InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
{
|
||||
#if 1
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_ho_wo_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global memory
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
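// Embed expresses each padded input coordinate as an affine combination of the
// filter and output coordinates: hip = y * ConvDilationH + ho * ConvStrideH and
// wip = x * ConvDilationW + wo * ConvStrideW (the trailing 0 is the constant offset)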
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
#if 0
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#else // hack
|
||||
constexpr auto wei_e_k_global_desc_old =
|
||||
WeiGlobalDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
constexpr auto wei_e_k_global_desc = make_native_tensor_descriptor(
|
||||
wei_e_k_global_desc_old.GetLengths(), wei_e_k_global_desc_old.GetStrides());
|
||||
#endif
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
in_e_n1_b_n2_block_desc.GetLength(I0),
|
||||
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
|
||||
in_e_n1_b_n2_block_desc.GetLength(I3),
|
||||
in_e_n1_b_n2_block_desc.GetStride(I0));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// main loop over the reduction dimension E: stage input and weight tiles in LDS,
// run the blockwise GEMM, then advance the source slice windows
|
||||
for(index_t e = 0; e < E; e += EPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
static_assert(K % (K1 * K2) == 0, "wrong!");
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
|
||||
make_native_tensor_descriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
|
||||
reorder_tensor_descriptor_given_upper2lower(out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc,
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_n2_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_desc.CalculateOffset(
|
||||
{k_thread_data_on_global, 0, b_thread_data_on_global, 0});
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
7,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
}
|
||||
}
|
||||
#else
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_h_w_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
|
||||
"wrong! global vector load of input tensor is wrong");
|
||||
|
||||
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// input
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// weight
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);
|
||||
print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
|
||||
print_tensor_descriptor("in_n0_n1_n2_c_y_ho_x_wo_global_desc: ",
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc);
|
||||
print_tensor_descriptor("in_e_n1_b_n2_global_desc: ", in_e_n1_b_n2_global_desc);
|
||||
|
||||
auto coord3 = make_tensor_coordinate_v2(in_e_n1_b_n2_global_desc, {1, 1, 1, 1});
|
||||
|
||||
auto idx3 = coord3.GetIndex();
|
||||
auto idx2 = coord3.GetLowerCoordinate().GetIndex();
|
||||
auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
auto idx0 =
|
||||
coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
|
||||
print_array("idx3: ", idx3);
|
||||
print_array("idx2: ", idx2);
|
||||
print_array("idx1: ", idx1);
|
||||
print_array("idx0: ", idx0);
|
||||
}
|
||||
#else
|
||||
index_t itmp = get_block_1d_id() + get_thread_local_1d_id();
|
||||
auto wei_coord1 = make_tensor_coordinate_v2(wei_e_k_global_desc, {itmp, itmp + 1});
|
||||
|
||||
auto step_sizes = make_multi_index(EPerBlock, 0);
|
||||
|
||||
wei_coord1 += step_sizes;
|
||||
|
||||
p_out_global[0] = wei_coord1.GetLowerCoordinate().GetIndex()[0];
|
||||
p_out_global[1] = wei_coord1.GetLowerCoordinate().GetIndex()[1];
|
||||
p_out_global[2] = wei_coord1.GetLowerCoordinate().GetIndex()[2];
|
||||
p_out_global[3] = wei_coord1.GetLowerCoordinate().GetIndex()[3];
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
};

} // namespace ck
#endif
@@ -1,331 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
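// v4r4 keeps the batch dimension whole: B = N * Ho * Wo is not split into
// N0/N1/N2, so the blockwise GEMM below is simply
// c[KPerBlock, BPerBlock] += transpose(a[EPerBlock, KPerBlock]) * b[EPerBlock, BPerBlock]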
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_B,
|
||||
class InBlockCopyClusterLengths_E_B,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_B,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N, Ho, Wo]
|
||||
constexpr auto in_n_ho_wo_global_desc =
|
||||
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
|
||||
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, B], src of blockwise copy
|
||||
constexpr auto in_e_b_global_desc =
|
||||
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<3, 4, 5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, B], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_b_block_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, BPerBlock] is in LDS
|
||||
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(
|
||||
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
|
||||
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat =
|
||||
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_b_block_mtx_desc),
|
||||
decltype(c_k0k1_b0b1_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output global descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
// This is a hack, because slicing a merged dimension is not supported yet.
|
||||
// This should be replaced with the logic above, once support for slicing a merged
|
||||
// dimension becomes available
|
||||
// dst descriptor
|
||||
constexpr auto out_k0_k1_b_global_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<0, 3, 4>{});
|
||||
|
||||
// src descriptor
|
||||
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
using OutThreadCopySliceLengths =
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
|
||||
|
||||
auto threadwise_out_copy =
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(out_k0_k1_b_global_desc),
|
||||
OutThreadCopySliceLengths,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
2,
|
||||
2,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B>(
|
||||
{0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
|
||||
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
|
||||
{
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
|
||||
threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
|
||||
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
|
||||
}
|
||||
}
|
||||
}
|
||||
};

} // namespace ck
#endif
@@ -1,461 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
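// Padded variant of the v4r4 kernel above: the input (Hi, Wi) dimensions are
// padded with LeftPads/RightPads via a Pad transform before being merged into
// E = C * Y * X and B = N * Ho * Wo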
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_B,
|
||||
typename InBlockCopyClusterLengths_E_B,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopyDataPerAccess_B,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
|
||||
{
|
||||
#if 1
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_ho_wo_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
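// Illustrative sizing example (hypothetical shapes, not taken from any config in this repo):
// for N = 128, C = 64, Y = X = 3, Ho = Wo = 28,
//   E = 64 * 3 * 3    = 576     (GEMM reduction dimension)
//   B = 128 * 28 * 28 = 100352  (GEMM "N" dimension)
// so the implicit GEMM multiplies a [E, K] weight view with a [E, B] input view
// to produce the [K, B] output view.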
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global mem
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
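// Summary of the descriptor chain above (illustrative, hedged):
//   (N, C, Hi, Wi) --Pad-->   (N, C, Hip, Wip)
//                  --Embed--> (N, C, Y, Ho, X, Wo), with hip = y * ConvDilationH + ho * ConvStrideH
//                                                    and wip = x * ConvDilationW + wo * ConvStrideW
//                  --Merge--> (E, B), E = merge(C, Y, X), B = merge(N, Ho, Wo)
// so entry (e, b) of the GEMM input matrix addresses the padded input pixel that filter
// tap (y, x) of channel c touches when producing output (n, ho, wo).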
|
||||
|
||||
// LDS mem
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_b_block_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
|
||||
|
||||
// input blockwise copy
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
// global mem
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// LDS
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// weight blockwise copy
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, BPerBlock] is in LDS
|
||||
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(
|
||||
KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
|
||||
BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat =
|
||||
BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
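// Illustrative example (hypothetical tuning values): KPerBlock = 128, GemmMPerThreadSubC = 4,
// GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4 => one cluster pass covers 4 * 4 * 4 = 64 rows,
// so GemmMRepeat = 128 / 64 = 2 and each thread keeps GemmMRepeat * GemmMPerThreadSubC = 8 rows
// of the C tile in registers.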
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_b_block_mtx_desc),
|
||||
decltype(c_k0k1_b0b1_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
// src descriptor
|
||||
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
|
||||
|
||||
// dst descriptor
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
constexpr index_t K0 = K / K1;
|
||||
constexpr index_t B0 = B / B1;
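// Illustrative example of the K = K0 * K1 split (hypothetical sizes): K = 256 and
// K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster = 64 give K0 = 4;
// a flat index k_thread_data_on_global = 150 then decomposes as
// (k0, k1) = (150 / 64, 150 % 64) = (2, 22), which is exactly the origin used for the
// threadwise output copy below.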
|
||||
|
||||
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
|
||||
out_k_b_global_desc,
|
||||
make_tuple(Unmerge<Sequence<K0, K1>>{}, Unmerge<Sequence<B0, B1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// output threadwise copy
|
||||
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_k0_k1_b0_b1_thread_desc),
|
||||
decltype(out_k0_k1_b0_b1_global_desc),
|
||||
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 4, 1>::type,
|
||||
3,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1});
|
||||
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
}
|
||||
}
|
||||
#else
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_ho_wo_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
|
||||
(X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
constexpr index_t K0 = K / K1;
|
||||
constexpr index_t B0 = B / B1;
|
||||
|
||||
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
|
||||
out_k_b_global_desc,
|
||||
make_tuple(Unmerge<Sequence<K0, K1>>{}, Unmerge<Sequence<B0, B1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
#if 1
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("in_e_b_global_desc: ", in_e_b_global_desc);
|
||||
print_tensor_descriptor("in_n_c_y_ho_x_wo_global_desc: ", in_n_c_y_ho_x_wo_global_desc);
|
||||
print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
|
||||
print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);
|
||||
|
||||
auto coord3 = make_tensor_coordinate_v2(in_e_b_global_desc, {1, 1});
|
||||
|
||||
auto idx3 = coord3.GetIndex();
|
||||
auto idx2 = coord3.GetLowerCoordinate().GetIndex();
|
||||
auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
auto idx0 =
|
||||
coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
|
||||
print_array("idx3: ", idx3);
|
||||
print_array("idx2: ", idx2);
|
||||
print_array("idx1: ", idx1);
|
||||
print_array("idx0: ", idx0);
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("out_k0_k1_b0_b1_global_desc: ", out_k0_k1_b0_b1_global_desc);
|
||||
print_tensor_descriptor("out_k_b_global_desc: ", out_k_b_global_desc);
|
||||
print_tensor_descriptor("out_n_k_ho_wo_global_desc: ", out_n_k_ho_wo_global_desc);
|
||||
|
||||
auto coord2 = make_tensor_coordinate_v2(out_k0_k1_b0_b1_global_desc, {1, 1, 1, 1});
|
||||
|
||||
auto idx2 = coord2.GetIndex();
|
||||
auto idx1 = coord2.GetLowerCoordinate().GetIndex();
|
||||
auto idx0 = coord2.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
|
||||
print_array("idx2: ", idx2);
|
||||
print_array("idx1: ", idx1);
|
||||
print_array("idx0: ", idx0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,259 +0,0 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_direct_convolution.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_direct_convolution.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class TInWei,
|
||||
class TOut,
|
||||
class TAccum,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t ScalarPerVector,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t CPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t BlockSize,
|
||||
index_t GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
|
||||
TOut* const __restrict__ p_out_global)
|
||||
{
|
||||
using in_scalar_t = TInWei;
|
||||
using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
|
||||
using out_scalar_t = TOut;
|
||||
using accum_t = TAccum;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_vec_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
|
||||
constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
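// Illustrative example (hypothetical tile sizes): a block that produces a
// HoPerBlock x WoPerBlock = 16 x 16 output tile with a Y x X = 3 x 3 filter (unit stride)
// must load an 18 x 18 input tile, i.e. the output tile plus the (Y - 1, X - 1) halo.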
|
||||
|
||||
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<KPerBlock, CPerBlock * Y * X>{},
|
||||
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr auto wei_kcyx_vec_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
|
||||
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
|
||||
|
||||
// shared mem
|
||||
constexpr index_t in_block_element_size =
|
||||
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr index_t wei_block_element_size =
|
||||
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ in_vector_mem_t
|
||||
p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ in_vector_mem_t
|
||||
p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
|
||||
|
||||
// threadwise tensors
|
||||
constexpr index_t HiPerThread = HoPerThread + Y - 1;
|
||||
constexpr index_t WiPerThread = WoPerThread + X - 1;
|
||||
|
||||
constexpr auto in_nchw_vec_thread_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
|
||||
in_nchw_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);
|
||||
|
||||
// register
|
||||
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
const index_t block_id = blockIdx.x;
|
||||
|
||||
index_t itmp = block_id;
|
||||
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const index_t h_block_work_id = itmp / WBlockWork;
|
||||
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
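// Illustrative decode of block_id (hypothetical work counts): with KBlockWork = 4,
// HBlockWork = WBlockWork = 7 and block_id = 100:
//   n_block_work_id = 100 / 196 = 0 (remainder 100)
//   k_block_work_id = 100 / 49  = 2 (remainder 2)
//   h_block_work_id = 2 / 7     = 0, w_block_work_id = 2
// so this block owns tile (n, k, ho, wo) = (0, 2, 0, 2), scaled by the per-block sizes above.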
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
|
||||
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
|
||||
|
||||
// divide thread work
|
||||
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
|
||||
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
|
||||
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
|
||||
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
|
||||
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
itmp = thread_id;
|
||||
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
|
||||
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
|
||||
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
|
||||
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
|
||||
const index_t h_thread_work_id = itmp / WThreadWork;
|
||||
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
|
||||
|
||||
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
|
||||
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
|
||||
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
|
||||
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
|
||||
|
||||
const index_t hi_thread_data_begin = ho_thread_data_begin;
|
||||
const index_t wi_thread_data_begin = wo_thread_data_begin;
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(in_nchw_vec_global_desc),
|
||||
decltype(in_nchw_vec_block_desc),
|
||||
decltype(in_nchw_vec_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_kcyx_vec_global_desc),
|
||||
decltype(wei_kcyx_vec_block_desc),
|
||||
decltype(wei_kcyx_vec_block_desc.GetLengths()),
|
||||
1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_ke_vec_global_desc),
|
||||
decltype(wei_ke_vec_block_desc),
|
||||
decltype(wei_ke_vec_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
#if 1 // debug
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
|
||||
#endif
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(
|
||||
p_in_vec_global +
|
||||
in_nchw_vec_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_vec_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_vec_global +
|
||||
wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// copy output tensor from register to global mem
|
||||
threadwise_4d_tensor_copy(out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -1,298 +0,0 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class LowerPads,
|
||||
class UpperPads,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t CPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1>
|
||||
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
|
||||
const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
// NPerThread == NPerBlock, because the input format in LDS is [C,Hi,Wi,N]:
|
||||
// for the GEMM trans([C,K]) * [C,Wo*N], a single thread has to cover all of "N"
|
||||
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
|
||||
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
|
||||
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HPadLow = LowerPads{}.Get(I0);
|
||||
constexpr index_t WPadLow = LowerPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HPadUp = UpperPads{}.Get(I0);
|
||||
constexpr index_t WPadUp = UpperPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of wei in global mem
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});
|
||||
|
||||
// flattened (2d) tensor view of wei in LDS
|
||||
constexpr auto wei_ek_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
|
||||
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
|
||||
const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
|
||||
|
||||
const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
|
||||
const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
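// Illustrative note (hedged): only border blocks see non-zero pads. With HPadLow = WPadLow = 1,
// the block at (h_block_work_id, w_block_work_id) = (0, 0) is loaded with a one-pixel top/left
// pad, while interior blocks pass zeros here and read their full input window directly.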
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
|
||||
{
|
||||
printf(
|
||||
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
BlockwiseChwnTensorCopyPadded<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
|
||||
#if 0
|
||||
// weight: format is [C,Y,X,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_cyxk_global_desc),
|
||||
decltype(wei_cyxk_block_desc),
|
||||
decltype(wei_cyxk_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
// weight: format is [C*Y*X,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
// weight: format is [C*Y*X,K]
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
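// Restated for clarity (editorial, hedged): the batch dimension of this blockwise GEMM is Ho.
// For each filter tap (y, x), the loop further below issues one batched GEMM of shape
// trans([CPerBlock, KPerBlock]) * [CPerBlock, WoPerBlock * NPerBlock] per output row,
// so the full accumulation adds Y * X rank-CPerBlock updates into the C tile.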
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxwn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_chwn_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxwn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr index_t in_block_element_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_element_size];
|
||||
__shared__ Float p_wei_block[wei_block_element_size];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_global_block_begin =
|
||||
p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.Run(p_in_global,
|
||||
c_block_data_begin,
|
||||
ho_block_data_begin,
|
||||
wo_block_data_begin,
|
||||
n_block_data_begin,
|
||||
p_in_block,
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.Run(
|
||||
p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t ho_thread_data_begin = matrix_c_index.batch;
|
||||
const index_t k_thread_data_begin = matrix_c_index.row;
|
||||
const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
#if 0
|
||||
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
|
||||
get_block_1d_id(), get_thread_local_1d_id(),
|
||||
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
|
||||
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
|
||||
p_out_thread[0]);
|
||||
#endif
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
|
||||
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
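// Reading of the reorder map (illustrative, hedged): reorder_khwn_from_hkwn = {1, 0, 2, 3}
// swaps the first two dimensions, so a value held at (ho, k, wo, n) in the thread tensor
// is written to (k, ho, wo, n) of the global output tensor.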
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.GetOffsetFromMultiIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_khwn_from_hkwn);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -5,12 +5,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t Length>
|
||||
struct Dimension
|
||||
{
|
||||
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
|
||||
};
|
||||
|
||||
template <index_t Length, index_t Stride>
|
||||
struct NativeDimension
|
||||
{
|
||||
|
||||
@@ -193,7 +193,7 @@ struct TensorCoordinate
|
||||
private:
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
@@ -201,7 +201,7 @@ struct TensorCoordinate
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
|
||||
@@ -326,14 +326,14 @@ struct TensorCoordinate_deprecated
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
#ifndef CK_TENSOR_VIEW_HPP
|
||||
#define CK_TENSOR_VIEW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// TensorDesc is ConstantTensorDescriptor or ConstantMergedTensorDescriptor
|
||||
template <class TensorDesc, class TData>
|
||||
struct NormalTensorView
|
||||
{
|
||||
using type = NormalTensorView;
|
||||
using tensor_desc_type = TensorDesc;
|
||||
using coordinate_type = typename NormalTensorCoordinate_deprecated<TensorDesc>::type;
|
||||
using data_type = TData;
|
||||
|
||||
static constexpr auto nDim = TensorDesc::GetNumOfDimension();
|
||||
|
||||
__host__ __device__ constexpr NormalTensorView(TData* p_data) : mpData{p_data} {}
|
||||
|
||||
__host__ __device__ constexpr NormalTensorView() : NormalTensorView{nullptr} {}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return nDim; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return TensorDesc::GetLengths(); }
|
||||
|
||||
__host__ __device__ const TData& operator[](coordinate_type coord) const
|
||||
{
|
||||
return mpData[coord.GetOffset()];
|
||||
}
|
||||
|
||||
__host__ __device__ TData& operator()(coordinate_type coord) const
|
||||
{
|
||||
return mpData[coord.GetOffset()];
|
||||
}
|
||||
|
||||
template <class IDim, class DataPerVector>
|
||||
__host__ __device__ static constexpr auto IsVectorizationAllowed(IDim, DataPerVector)
|
||||
{
|
||||
return TensorDesc::IsVectorizationAllowed(IDim{}, DataPerVector{});
|
||||
}
|
||||
|
||||
template <class IDim, class DataPerVector>
|
||||
__host__ __device__ auto Vectorize(IDim idim, DataPerVector data_per_vector) const
|
||||
{
|
||||
static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");
|
||||
|
||||
using vector_t = typename vector_type<TData, data_per_vector>::MemoryType;
|
||||
return NormalTensorView<decltype(TensorDesc::Vectorize(idim, data_per_vector)), vector_t>(
|
||||
reinterpret_cast<vector_t*>(mpData));
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ auto Slice(coordinate_type slice_origin, Sequence<Is...> slice_lengths)
|
||||
{
|
||||
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
|
||||
|
||||
return NormalTensorView<decltype(TensorDesc::Slice(slice_lengths)), TData>(
|
||||
mpData + slice_origin.GetOffset());
|
||||
}
|
||||
|
||||
template <class IDim, class SliceLen>
|
||||
__host__ __device__ auto
|
||||
Slice(coordinate_type slice_origin, IDim idim, SliceLen slice_len) const
|
||||
{
|
||||
return NormalTensorView<decltype(TensorDesc::Slice(idim, slice_len)), TData>(
|
||||
mpData + slice_origin.GetOffset());
|
||||
}
|
||||
|
||||
// slice_window is a slicing window on "*this"
|
||||
template <class SliceWindow, class T, bool PositiveDirection>
|
||||
__device__ void MoveSliceWindow(SliceWindow& slice_window,
|
||||
T step_sizes,
|
||||
integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
if(PositiveDirection)
|
||||
{
|
||||
slice_window.mpData += coordinate_type{step_sizes}.GetOffset();
|
||||
}
|
||||
else
|
||||
{
|
||||
slice_window.mpData -= coordinate_type{step_sizes}.GetOffset();
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
data_type* mpData;
|
||||
};
|
||||
|
||||
template <class... Xs, class TData>
|
||||
__host__ __device__ constexpr auto make_TensorView(ConstantTensorDescriptor<Xs...>, TData* p_data)
|
||||
{
|
||||
return NormalTensorView<ConstantTensorDescriptor<Xs...>, TData>{p_data};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,124 +0,0 @@
|
||||
#ifndef CK_TENSOR_VISIT_HPP
|
||||
#define CK_TENSOR_VISIT_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
#include "dimension_transform.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class TensorDescriptor>
|
||||
struct TensorVisit
|
||||
{
|
||||
using Index = typename TensorDescriptor::Index;
|
||||
using Coordinate = typename TensorCoordinate<TensorDescriptor>::type;
|
||||
|
||||
__host__ __device__ static void Run_v1(Index idx_begin)
|
||||
{
|
||||
const auto coord_begin = Coordinate(idx_begin);
|
||||
|
||||
ford<TensorDescriptor::GetLengths()>{}(
|
||||
[&](auto idx_diff) { index_t offset = (coord_begin + idx_diff).GetOffset(); });
|
||||
}
|
||||
|
||||
__host__ __device__ static void Run_v2(Index idx_begin)
|
||||
{
|
||||
const auto coord_begin = Coordinate(idx_begin);
|
||||
|
||||
ford<TensorDescriptor::GetLengths()>{}([&](auto idx_diff) {
|
||||
index_t offset_diff = coord_begin.GetOffsetDiff(idx_diff);
|
||||
index_t offset = coord_begin.GetOffset() + offset_diff;
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ static void Run_v3(Index idx_begin)
|
||||
{
|
||||
const auto coord_begin = Coordinate(idx_begin);
|
||||
|
||||
constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();
|
||||
constexpr auto nonlinear_dimensions = TensorDescriptor::GetNonLinearDimensions();
|
||||
|
||||
constexpr auto lengths = TensorDescriptor::GetLengths();
|
||||
|
||||
constexpr auto linear_dimension_lengths_hack =
|
||||
lambda_HackLengths{}(lengths, linear_dimensions);
|
||||
constexpr auto nonlinear_dimension_lengths_hack =
|
||||
lambda_HackLengths{}(lengths, nonlinear_dimensions);
|
||||
|
||||
ford<nonlinear_dimension_lengths_hack>{}([&](auto idx_diff_nonlinear_hack) {
|
||||
// run-time component
|
||||
index_t offset_diff_nonlinear = coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack);
|
||||
|
||||
ford<linear_dimension_lengths_hack>{}([&](auto idx_diff_linear_hack) {
|
||||
// compile-time component
|
||||
index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);
|
||||
|
||||
index_t offset =
|
||||
coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ static void Run_v4(Index idx_begin)
|
||||
{
|
||||
const auto coord_begin = Coordinate(idx_begin);
|
||||
|
||||
constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();
|
||||
|
||||
constexpr auto nonlinear_independent_dimension_groups =
|
||||
TensorDescriptor::GetNonLinearIndependentDimensionGroups();
|
||||
|
||||
constexpr auto lengths = TensorDescriptor::GetLengths();
|
||||
|
||||
constexpr auto linear_dimension_lengths = lambda_HackLengths{}(lengths, linear_dimensions);
|
||||
|
||||
// run-time component
|
||||
index_t offset_diff_nonlinear = 0;
|
||||
|
||||
template <index_t NGroup>
|
||||
struct f_recursion
|
||||
{
|
||||
template <index_t IGroup>
|
||||
__host__ __device__ void Run(Number<IGroup>)
|
||||
{
|
||||
constexpr auto nonlinear_independent_dimensions_igroup =
|
||||
nonlinear_independent_dimension_groups.Get(igroup);
|
||||
|
||||
constexpr auto nonlinear_independent_lengths_igroup =
|
||||
lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);
|
||||
|
||||
ford<nonlinear_independent_lengths_igroup>{}(
|
||||
[&](auto idx_diff_nonlinear_igroup_hack) {
|
||||
// run-time component
|
||||
offset_diff_nonlinear +=
|
||||
coord_begin.GetOffsetDiff(idx_diff_nonlinear_igroup_hack);
|
||||
|
||||
Run(Number<IGroup + 1>{});
|
||||
});
|
||||
};
|
||||
|
||||
// inner-most work
|
||||
template <>
|
||||
__host__ __device__ void Run(Number<NGroup>)
|
||||
{
|
||||
ford<linear_dimension_lengths>{}([&](auto idx_diff_linear_hack) {
|
||||
// compile-time component
|
||||
index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);
|
||||
|
||||
index_t offset =
|
||||
coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// run-time component
|
||||
index_t offset_diff_nonlinear = 0;
|
||||
|
||||
f_recursion<nonlinear_independent_dimension_groups.GetSize()>{}.Run();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,806 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc, class F>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
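// Illustrative example (hypothetical sizes): GetElementSize() = 1000 and BlockSize = 256
// give NLoop = 3, covering elements 0..767; has_tail is then true and the guarded pass
// below lets threads 0..231 handle the remaining elements 768..999.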
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < desc.GetElementSize())
|
||||
{
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
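// Illustrative example (hedged): with MapDst2Src = Sequence<1, 0>, the source element at
// (i0, i1) ends up at (i1, i0) in the destination, i.e. a blockwise 2D transpose; the
// identity map Sequence<0, 1> reduces this to a plain blockwise copy.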
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc>
|
||||
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise2dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise2dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
|
||||
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I0) % DataPerRead == 0,
|
||||
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D1 dimension,
|
||||
// but we need to make sure dst stride0 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// needs to be aligned to float4 and float2
// stride1 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
index_t ThreadPerDim0,
|
||||
index_t ThreadPerDim1>
|
||||
struct Blockwise2dTensorCopy2
|
||||
{
|
||||
index_t mThreadId0;
|
||||
index_t mThreadId1;
|
||||
|
||||
__device__ Blockwise2dTensorCopy2()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
|
||||
"wrong! stride is not 1!\n");
|
||||
|
||||
mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
|
||||
mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
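        // Threads are laid out row-major over a ThreadPerDim0 x ThreadPerDim1 grid.
        // For example (illustrative numbers only): with ThreadPerDim1 = 16, thread 35 gets
        // (mThreadId0, mThreadId1) = (2, 3).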
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
|
||||
|
||||
using Float4 = float4;
|
||||
using Float2 = float2;
|
||||
|
||||
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
|
||||
return;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
// check alignment
|
||||
constexpr bool align_v4 =
|
||||
src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;
|
||||
|
||||
constexpr bool align_v2 =
|
||||
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
|
||||
|
||||
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr index_t L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
|
||||
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
|
||||
|
||||
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
|
||||
|
||||
constexpr index_t Dim1V2Loop =
|
||||
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
|
||||
|
||||
constexpr index_t Dim1V1Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
|
||||
ThreadPerDim1;
|
||||
|
||||
constexpr bool d1_has_tail =
|
||||
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
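        // Dim-1 is consumed in float4 chunks first, then float2, then scalar, and a per-thread
        // tail handles what is left. Illustrative example (hypothetical numbers, assuming both
        // strides are 4-aligned): with ThreadPerDim1 = 8 and L1 = 100, Dim1V4Loop = 3 (96
        // elements), Dim1V2Loop = 0, Dim1V1Loop = 0, and the tail covers the last 4 elements.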
|
||||
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
|
||||
{
|
||||
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
// v4
|
||||
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
index_t did1 =
|
||||
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
d1v1loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// dim-1 tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dim-0 tail
|
||||
if(d0_has_tail)
|
||||
{
|
||||
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
if(did0 < L0)
|
||||
{
|
||||
|
||||
// v4
|
||||
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
|
||||
2 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
d1v1loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride1 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
|
||||
Array<index_t, 2> dst_block_data_multi_id_begin)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
|
||||
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I0) % DataPerRead == 0,
|
||||
"src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
// we allow out-of-bound read from src in D1 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
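        // Illustrative sizing (hypothetical numbers): BlockSize = 256, L1 = 32, DataPerRead = 4
        // gives thread_per_d1 = 8 and thread_per_d0 = 32, so all 256 threads are active and each
        // thread owns one float4 segment of one row per loop iteration.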
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
|
||||
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
src_block_data_multi_id_begin +
|
||||
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
dst_block_data_multi_id_begin +
|
||||
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterBufferSize() const
|
||||
{
|
||||
static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
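        // i.e. each thread stages ceil(L0 / thread_per_d0) rows of DataPerRead elements in
        // registers; e.g. (illustrative) L0 = 64, thread_per_d0 = 32, DataPerRead = 4 -> 8 floats.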
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
__device__ void RunLoadRegisterBuffer_asm(const Float* __restrict__ p_src,
|
||||
Float* p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
#if 0
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride]));
|
||||
#else
|
||||
static_assert(is_same<float, Float>{} && DataPerRead == 4,
|
||||
"global_load is only for float4");
|
||||
|
||||
global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
|
||||
reinterpret_cast<const vector_t*>(
|
||||
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
|
||||
#endif
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer_asm(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
#if 0
|
||||
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
|
||||
#else
|
||||
static_assert(is_same<float, Float>{} && DataPerRead == 4,
|
||||
"ds_write_b128 is only for float4");
|
||||
|
||||
ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
|
||||
&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
|
||||
#endif
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -1,378 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise3dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise3dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
|
||||
"wrong! only support stride2 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I1) % DataPerRead == 0,
|
||||
"src and dst stride1 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[3];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride2 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
class ThreadPerDims,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise3dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise3dTensorCopy3()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I1) % DataPerRead == 0,
|
||||
"wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
// we allow out-of-bound read from src in D2 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
|
||||
"wrong! L0, L1, L2 should be divided evenly!\n");
|
||||
|
||||
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
{
|
||||
static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
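        // The clipboard is a packed, per-thread register tile of shape
        // [nloop_d0, nloop_d1, nloop_d2 * DataPerRead]; each global vector load below lands at
        // clipboard position (iloop_d0, iloop_d1, iloop_d2 * DataPerRead).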
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -1,779 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_4D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc, class F>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto desc = make_ConstantTensorDescriptor_packed(dst_desc.GetLengths());
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
is -= did1 * desc.GetStride(I1);
|
||||
|
||||
const index_t did2 = is / desc.GetStride(I2);
|
||||
|
||||
is -= did2 * desc.GetStride(I2);
|
||||
|
||||
const index_t did3 = is / desc.GetStride(I3);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < desc.GetElementSize())
|
||||
{
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
is -= did1 * desc.GetStride(I1);
|
||||
|
||||
const index_t did2 = is / desc.GetStride(I2);
|
||||
|
||||
is -= did2 * desc.GetStride(I2);
|
||||
|
||||
const index_t did3 = is / desc.GetStride(I3);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
|
||||
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
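    // MapDst2Src gives, for each dst dimension k, the src dimension it reads, so the loop below
    // applies f to p_src[did[0..3]] and p_dst[did[IR0], did[IR1], did[IR2], did[IR3]]. For example
    // (illustrative only), MapDst2Src = Sequence<0, 3, 1, 2> pairs src element (n0, n1, n2, n3)
    // with dst position (n0, n3, n1, n2).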
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[src_index], p_dst[dst_index]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[src_index], p_dst[dst_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc>
|
||||
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise4dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise4dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class DstOpLengths,
|
||||
class GlobalLowerPads>
|
||||
struct BlockwiseChwnTensorCopyPadded
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src,
|
||||
index_t c_block_data_begin,
|
||||
index_t ho_block_data_begin,
|
||||
index_t wo_block_data_begin,
|
||||
index_t n_block_data_begin,
|
||||
Float* __restrict__ p_dst,
|
||||
index_t h_block_pad_low,
|
||||
index_t w_block_pad_low,
|
||||
index_t h_block_pad_up,
|
||||
index_t w_block_pad_up) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(DstOpLengths{});
|
||||
|
||||
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
|
||||
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
const Float* p_src_tmp = p_src +
|
||||
src_desc.GetOffsetFromMultiIndex(
|
||||
c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
|
||||
print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");
|
||||
|
||||
printf("%u %u, \t"
|
||||
"h_global_pad_low %u w_global_pad_low %u \t"
|
||||
"h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
|
||||
"\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
h_global_pad_low,
|
||||
w_global_pad_low,
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
p_dst[bindex] =
|
||||
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
|
||||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
|
||||
? Float(0)
|
||||
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
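            // Elements whose (H, W) position falls inside this block's pad region are written as
            // zero; everything else is copied from p_src_tmp, which already points at the
            // (unpadded) top-left corner of this block's input window.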
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t bindex =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
p_dst[bindex] =
|
||||
(did[1] < h_block_pad_low ||
|
||||
did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
|
||||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
|
||||
? Float(0)
|
||||
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(
|
||||
did[0], did[1], did[2], did[3])];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride3 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
class ThreadPerDims,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise4dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise4dTensorCopy3()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 && L2 % thread_per_d2 == 0,
|
||||
"wrong! L0, L1, L2 should be divided evenly!\n");
|
||||
|
||||
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(ThreadPerDims{});
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterBufferSize() const
|
||||
{
|
||||
static_assert(is_same<Float, float>{}, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
struct Blockwise4dTensorCopyReorder1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
};
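// Minimal sketch (not part of the original code) of what GetOffsetFromMultiIndex computes for
// a packed 4d descriptor like those built by make_ConstantTensorDescriptor_packed, assuming
// the packed layout is row-major with strides {L1*L2*L3, L2*L3, L3, 1}:
__device__ constexpr index_t packed_4d_offset_sketch(
    index_t L1, index_t L2, index_t L3, index_t i0, index_t i1, index_t i2, index_t i3)
{
    // ((i0 * L1 + i1) * L2 + i2) * L3 + i3 == i0*L1*L2*L3 + i1*L2*L3 + i2*L3 + i3
    return ((i0 * L1 + i1) * L2 + i2) * L3 + i3;
}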
} // namespace
|
||||
|
||||
#endif
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
|
||||
@@ -484,14 +483,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
|
||||
__device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
|
||||
{
|
||||
#if 0
|
||||
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
|
||||
#else // tweaking
|
||||
mThreadwiseLoad.template Run_optimized_address_calculation<TData,
|
||||
BlockSrcAddressSpace,
|
||||
ThreadBufferAddressSpace>(
|
||||
p_block_src, p_thread_buffer);
|
||||
#endif
|
||||
mThreadwiseLoad.Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(p_block_src,
|
||||
p_thread_buffer);
|
||||
}
|
||||
|
||||
template <typename TData,
|
||||
@@ -499,14 +492,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
|
||||
{
|
||||
#if 0
|
||||
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
|
||||
#else // tweaking
|
||||
mThreadwiseStore.template Run_optimized_address_calculation<TData,
|
||||
ThreadBufferAddressSpace,
|
||||
BlockDstAddressSpace>(
|
||||
p_thread_buffer, p_block_dst);
|
||||
#endif
|
||||
mThreadwiseStore.Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(p_thread_buffer,
|
||||
p_block_dst);
|
||||
}
|
||||
|
||||
template <typename TData,
|
||||
@@ -563,130 +550,6 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
// this version uses TensorView and TensorCoordinate_deprecated
|
||||
template <index_t BlockSize,
|
||||
typename SrcTensor,
|
||||
typename DstTensor,
|
||||
typename SliceLengths,
|
||||
typename SubLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
|
||||
using data_type = remove_cv_t<typename SrcTensor::data_type>;
|
||||
|
||||
using SrcCoordinate = typename SrcTensor::coordinate_type;
|
||||
using DstCoordinate = typename DstTensor::coordinate_type;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v3(SrcTensor src_block,
|
||||
SrcCoordinate src_block_slice_origin,
|
||||
DstTensor dst_block,
|
||||
DstCoordinate dst_block_slice_origin)
|
||||
: mThreadBuffer{make_TensorView(ThreadBufferDesc{}, mpBuffer)}
|
||||
{
|
||||
static_assert(
|
||||
nDim == SrcTensor::GetNumOfDimension() && nDim == DstTensor::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
|
||||
nDim == ThreadClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
|
||||
remove_cv_t<typename DstTensor::data_type>>{},
|
||||
"wrong! type conversion not supported yet");
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
|
||||
mThreadwiseLoad = ThreadwiseLoad(src_block,
|
||||
src_block_slice_origin + thread_data_id_begin,
|
||||
mThreadBuffer,
|
||||
make_zero_array<index_t, nDim>());
|
||||
|
||||
mThreadwiseStore = ThreadwiseStore(mThreadBuffer,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
dst_block,
|
||||
dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterBuffer() { mThreadwiseLoad.Run(); }
|
||||
|
||||
__device__ void RunStoreRegisterBuffer() const { mThreadwiseStore.Run(); }
|
||||
|
||||
__device__ void Run()
|
||||
{
|
||||
mThreadwiseLoad.Run();
|
||||
mThreadwiseStore.Run();
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
private:
|
||||
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3r1<SrcTensor,
|
||||
ThreadBufferTensor,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3r1<ThreadBufferTensor,
|
||||
DstTensor,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];
|
||||
|
||||
ThreadBufferTensor mThreadBuffer;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcLengths,
|
||||
class SrcSubLengths,
|
||||
class SrcClusterLengths,
|
||||
class MapDst2Src,
|
||||
class MapThreadCluster2SrcCluster,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite>
|
||||
struct BlockwiseTensorSliceReorderCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcLengths::GetSize();
|
||||
|
||||
index_t mThreadSrcOffset;
|
||||
index_t mThreadDstOffset;
|
||||
|
||||
__device__
|
||||
BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_multi_id_begin)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto src_lengths = SrcLengths{};
|
||||
|
||||
constexpr auto map_dst2src = MapDst2Src{};
|
||||
|
||||
constexpr auto src_sub_lengths = SrcSubLengths{};
|
||||
constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
|
||||
|
||||
constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
|
||||
|
||||
constexpr auto src_cluster_lengths = SrcClusterLengths{};
|
||||
constexpr auto thread_cluster_lengths =
|
||||
src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
|
||||
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_cluster_lengths);
|
||||
|
||||
// sanity check: data type
|
||||
static_assert(is_same<Float, float>{}, "wrong! only support float for now!\n");
|
||||
|
||||
// sanity check: nDim
|
||||
static_assert(SrcDesc::GetNumOfDimension() == nDim &&
|
||||
DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
|
||||
SrcSubLengths::GetSize() == nDim &&
|
||||
SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
|
||||
MapThreadCluster2SrcCluster::GetSize() == nDim,
|
||||
"wrong! nDim is not consistent\n");
|
||||
|
||||
// sanity check: BlockSize
|
||||
constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
|
||||
|
||||
static_assert(BlockSize >= num_active_thread,
|
||||
"wrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
// sanity check: work division
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t src_len = src_lengths.Get(I);
|
||||
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
|
||||
constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
|
||||
static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
|
||||
"wrong! cannot evenly divide Src tensor lengths");
|
||||
});
|
||||
|
||||
// sanity check: src read
|
||||
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
|
||||
"wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
|
||||
|
||||
static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
|
||||
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
|
||||
|
||||
static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
|
||||
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// sanity check: dst write
|
||||
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
|
||||
"wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
|
||||
|
||||
static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
|
||||
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
|
||||
|
||||
static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
|
||||
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// start dividing work
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
// compiler: will thread_multi_id, src_data_multi_id and dst_data_multi_id use separate
// registers, or only one copy?
|
||||
auto src_data_multi_id =
|
||||
reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim;
|
||||
// compiler: will it really compute index here, or be merged with
// GetOffsetFromMultiIndex and optimized away???
|
||||
src_data_multi_id(idim) *= src_sub_lengths.Get(IDim);
|
||||
});
|
||||
|
||||
// compiler: will it really compute index here, or be merged with
// GetOffsetFromMultiIndex and optimized away???
|
||||
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
|
||||
|
||||
mThreadSrcOffset =
|
||||
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
|
||||
|
||||
mThreadDstOffset =
|
||||
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
|
||||
}
|
||||
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id %5u %5u: "
|
||||
"thread_multi_id: %u %u, "
|
||||
"src_block_data_multi_id_begin: %u %u, "
|
||||
"src_data_multi_id: %u %u, "
|
||||
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
src_block_data_multi_id_begin[0],
|
||||
src_block_data_multi_id_begin[1],
|
||||
src_data_multi_id[0],
|
||||
src_data_multi_id[1],
|
||||
mThreadSrcOffset,
|
||||
mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
return thread_tensor_desc.GetElementSpace();
|
||||
}
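// Worked example (illustrative values, not from the original source): with
// SrcLengths = {8, 128}, SrcSubLengths = {1, 4} and SrcClusterLengths = {2, 32},
// src_data_per_cluster_per_dims = {2, 128}, repeat_lengths = {4, 1}, so the per-thread
// clipboard is a packed {4, 4} tensor and GetRegisterBufferSize() returns 16.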
__device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
|
||||
|
||||
threadwise_tensor_slice_copy(SrcDesc{},
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
thread_sub_tensor_lengths,
|
||||
Number<SrcDataPerRead>{});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
|
||||
|
||||
// reorder src_data_multi_id to get dst_data_multi_id
|
||||
constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
|
||||
|
||||
constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);
|
||||
|
||||
// write in the order of dst
|
||||
#if 1
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{});
|
||||
#else
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{},
|
||||
Number<DstDataPerWrite>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
Float p_clipboard[GetRegisterBufferSize()];
|
||||
|
||||
RunLoadRegisterBuffer(p_src, p_clipboard);
|
||||
RunStoreRegisterBuffer(p_clipboard, p_dst);
|
||||
}
|
||||
|
||||
// this function doesn't do a sanity check on whether the slicing window is out of the
// boundary of the tensor being sliced
|
||||
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
|
||||
}
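// Example (illustrative values, not from the original source): for a packed source tensor
// with lengths {8, 16, 16}, moving the window along dim 0 by StepSize = 1 in the positive
// direction adds 1 * GetStride(I0) = 16 * 16 = 256 to mThreadSrcOffset.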
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,60 +0,0 @@
|
||||
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
|
||||
#define CK_THREADWISE_4D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t nshift = NShift::mValue;
|
||||
|
||||
constexpr index_t did0_end =
|
||||
is_same<decltype(I0), IDim>{} ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
|
||||
|
||||
constexpr index_t did1_end =
|
||||
is_same<decltype(I1), IDim>{} ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
|
||||
|
||||
constexpr index_t did2_end =
|
||||
is_same<decltype(I2), IDim>{} ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
|
||||
|
||||
constexpr index_t did3_end =
|
||||
is_same<decltype(I3), IDim>{} ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
|
||||
|
||||
for(index_t did0 = 0; did0 < did0_end; ++did0)
|
||||
{
|
||||
for(index_t did1 = 0; did1 < did1_end; ++did1)
|
||||
{
|
||||
for(index_t did2 = 0; did2 < did2_end; ++did2)
|
||||
{
|
||||
for(index_t did3 = 0; did3 < did3_end; ++did3)
|
||||
{
|
||||
const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
|
||||
|
||||
p[dindex] = p[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
@@ -600,18 +599,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
|
||||
// Read vector from src.
|
||||
// 1. Source code version can take src of all kinds of memory-space
|
||||
// 2. Inline asm versions using global_load or buffer_load can only take
|
||||
// 2. Intrinsic version using buffer_load can only take
|
||||
// src from global-memory
|
||||
//
|
||||
// Comment for loading from global-memory:
|
||||
// When
|
||||
// When:
|
||||
// 1) using source code, in order for compiler to emit optimal
|
||||
// load instruction, or
|
||||
// 2) using inline asm (global_load or buffer_load), in order
|
||||
// for inline asm to be valid,
|
||||
// 2) using buffer_load intrinsic, in order for ISA to be valid,
|
||||
// following assumptions need to be satisfied:
|
||||
// 1. p_src need to be block-invariant (assumption)
|
||||
// 2. src_normal_offset must be calculated at compile time (guaranteed)
// 2. src_normal_offset must be calculated at compile time (guaranteed by
|
||||
// algorithm)
|
||||
// 3. src_merged_offset can be runtime value (no assumption imposed)
|
||||
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
@@ -698,18 +697,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
|
||||
// Write vector into dst.
|
||||
// 1. Source code version can take dst of all kinds of memory-space
|
||||
// 2. Inline asm versions using global_store or buffer_store can only take
|
||||
// 2. Intrinsic version using buffer_store can only take
|
||||
// dst from global-memory
|
||||
//
|
||||
// Comment for storing into global-memory:
|
||||
// When
|
||||
// When:
|
||||
// 1) using source code, in order for compiler to emit optimal
|
||||
// store instruction, or
|
||||
// 2) using inline asm (global_store or buffer_store), in order
|
||||
// for inline asm to be valid,
|
||||
// 2) using buffer_store intrinsic, in order for ISA to be valid,
|
||||
// following assumptions need to be satisfied:
|
||||
// 1. p_dst need to be block-invariant (assumption)
|
||||
// 2. dst_normal_offset must be calculated at compile time (guaranteed)
// 2. dst_normal_offset must be calculated at compile time (guaranteed by
|
||||
// algorithm)
|
||||
// 3. dst_merged_offset can be runtime value (no assumption imposed)
|
||||
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
@@ -751,152 +750,5 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
// this version uses TensorView and TensorCoordinate_deprecated
|
||||
template <typename SrcTensor,
|
||||
typename DstTensor,
|
||||
typename SliceLengths,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v3r1
|
||||
{
|
||||
static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
|
||||
using data_type = remove_cv_t<typename SrcTensor::data_type>;
|
||||
|
||||
using SrcCoordinate = typename SrcTensor::coordinate_type;
|
||||
using DstCoordinate = typename DstTensor::coordinate_type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1(SrcTensor src,
|
||||
SrcCoordinate src_slice_origin,
|
||||
DstTensor dst,
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrc{src},
|
||||
mDst{dst},
|
||||
mSrcSlice{src.Slice(src_slice_origin, SliceLengths{})},
|
||||
mDstSlice{dst.Slice(dst_slice_origin, SliceLengths{})}
|
||||
{
|
||||
static_assert(nDim == SrcTensor::GetNumOfDimension() &&
|
||||
nDim == DstTensor::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
|
||||
remove_cv_t<typename DstTensor::data_type>>{},
|
||||
"wrong! type conversion is not supported yet");
|
||||
|
||||
static_assert(decltype(mSrcSlice)::IsVectorizationAllowed(Number<SrcVectorAccessDim>{},
|
||||
Number<SrcDataPerAccess>{}) &&
|
||||
decltype(mDstSlice)::IsVectorizationAllowed(Number<DstVectorAccessDim>{},
|
||||
Number<DstDataPerAccess>{}),
|
||||
"wrong! vectorized access is not allowed");
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v3r1(
|
||||
SrcTensor{}, SrcCoordinate{}, DstTensor{}, DstCoordinate{})
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void Run() const
|
||||
{
|
||||
// buffer
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SrcTensor::GetLengths());
|
||||
data_type p_buffer[buffer_desc.GetElementSpace()];
|
||||
auto buffer = make_TensorView(buffer_desc, p_buffer);
|
||||
|
||||
// copy data from src into buffer
|
||||
{
|
||||
using src_vector_t = typename vector_type<data_type, SrcDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
|
||||
auto src_slice_vectorized =
|
||||
mSrcSlice.Vectorize(src_vector_access_dim, src_data_per_access);
|
||||
|
||||
ford<decltype(src_slice_vectorized.GetLengths()), SrcDimAccessOrder>{}(
|
||||
[&](auto src_vector_id) {
|
||||
// load vector from src
|
||||
const src_vector_t vector_data = src_slice_vectorized[src_vector_id];
|
||||
|
||||
// unpack vector into buffer
|
||||
auto src_scalar_id = src_vector_id;
|
||||
src_scalar_id(src_vector_access_dim) *= src_data_per_access;
|
||||
|
||||
for(index_t i = 0; i < SrcDataPerAccess; ++i)
|
||||
{
|
||||
auto id = make_zero_array<index_t, nDim>();
|
||||
id(src_vector_access_dim) = i;
|
||||
|
||||
buffer(src_scalar_id + id) =
|
||||
reinterpret_cast<const data_type*>(&vector_data)[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// copy data from buffer into dst
|
||||
{
|
||||
using dst_vector_t = typename vector_type<data_type, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
auto dst_slice_vectorized =
|
||||
mDstSlice.Vectorize(dst_vector_access_dim, dst_data_per_access);
|
||||
|
||||
ford<decltype(dst_slice_vectorized.GetLengths()), DstDimAccessOrder>{}(
|
||||
[&](auto dst_vector_id) {
|
||||
|
||||
dst_vector_t vector_data{};
|
||||
|
||||
// pack vector from buffer
|
||||
auto dst_scalar_id = dst_vector_id;
|
||||
dst_scalar_id(dst_vector_access_dim) *= dst_data_per_access;
|
||||
|
||||
for(index_t i = 0; i < DstDataPerAccess; ++i)
|
||||
{
|
||||
auto id = make_zero_array<index_t, nDim>();
|
||||
id(dst_vector_access_dim) = i;
|
||||
|
||||
reinterpret_cast<data_type*>(&vector_data)[i] = buffer[dst_scalar_id + id];
|
||||
}
|
||||
|
||||
// write vector into dst
|
||||
dst_slice_vectorized(dst_vector_id) = vector_data;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// T can be Sequence or Array
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
mSrc.MoveSliceWindow(mSrcSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
mDst.MoveSliceWindow(mDstSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
|
||||
}
|
||||
|
||||
private:
|
||||
using SrcSlice = decltype(SrcTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
|
||||
using DstSlice = decltype(DstTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
|
||||
|
||||
SrcTensor mSrc;
|
||||
DstTensor mDst;
|
||||
SrcSlice mSrcSlice;
|
||||
DstSlice mDstSlice;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// need to assume src and dst are aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
|
||||
__device__ void threadwise_tensor_slice_copy(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr index_t nDim = SrcOpLengths::GetSize();
|
||||
|
||||
static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
|
||||
"wrong! dimension not consistent");
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "src_desc");
|
||||
print_ConstantTensorDescriptor(dst_desc, "dst_desc");
|
||||
print_ConstantTensorDescriptor(ref_desc, "ref_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
|
||||
DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
|
||||
"wrong! only support stride[nDim-1] == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
|
||||
"wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L_Back = SrcOpLengths{}.Back();
|
||||
|
||||
static_assert(L_Back % DataPerRead == 0,
|
||||
"wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr index_t nRead = L_Back / DataPerRead;
|
||||
|
||||
static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
|
||||
static_for<0, nRead, 1>{}([&](auto IRead) {
|
||||
constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead * DataPerRead>{});
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
});
|
||||
});
|
||||
}
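// Worked example (illustrative values, not from the original source): with
// SrcOpLengths = {2, 1, 1, 8} and DataPerRead = 4, nRead = 2 and the static_ford/static_for
// pair above issues 2 * 1 * 1 * 2 = 4 vector copies, each moving 4 contiguous Floats.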
// access in order of src
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
ford<SrcOpLengths>{}([&](auto src_multi_id) {
|
||||
const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
});
|
||||
}
|
||||
|
||||
// access in order of dst
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
|
||||
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
});
|
||||
}
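// Note (illustrative, semantics inferred from the helper names): v1 above iterates in src
// order and scatters into dst, while this v2 iterates in dst order and gathers from src;
// e.g. for a 2d slice with MapDst2Src = {1, 0}, the copy presumably lands as a transpose.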
// access in order of dst
|
||||
// manually pack data into vector before write
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
index_t DstDataPerWrite>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
Number<DstDataPerWrite>)
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;
|
||||
|
||||
constexpr index_t nDim = SrcOpLengths::GetSize();
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
|
||||
"wrong! only support DstDataPerWrite == 1, 2 or 4");
|
||||
|
||||
static_assert(
|
||||
DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
|
||||
"wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
constexpr index_t L_Dst_Back = dst_op_lengths.Back();
|
||||
|
||||
static_assert(L_Dst_Back % DstDataPerWrite == 0,
|
||||
"wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");
|
||||
|
||||
constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;
|
||||
|
||||
ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
|
||||
static_for<0, nWrite, 1>{}([&](auto IWrite) {
|
||||
vector_t dst_vec_data;
|
||||
|
||||
// pack data
|
||||
static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
|
||||
const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite + IDstData);
|
||||
|
||||
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
vector_type<Float, DstDataPerWrite>::SetScalar(
|
||||
dst_vec_data, p_src[src_index], IDstData);
|
||||
});
|
||||
|
||||
// write data
|
||||
const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite);
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -48,36 +48,6 @@ struct type_convert
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
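// Usage sketch (assumption, not from the original source): fused_multiply_accumulate(acc, a, b)
// accumulates acc += a * b; for float operands the compiler may fuse it into a single
// v_mac/v_fma-style instruction, but that is left to the backend.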
#if 0
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
|
||||
|
||||
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x + s0.y * s1.y;
|
||||
}
|
||||
|
||||
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
|
||||
|
||||
// TODO: this interface is misleading; s0 and s1 are actually int8x4
|
||||
// need to make a better interface
|
||||
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
|
||||
{
|
||||
d = __dp4a(s0, s1, d);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
//#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
template <class T,
|
||||
@@ -20,7 +19,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
index_t nrepeat)
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -171,46 +170,42 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
//#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
|
||||
|
||||
template <typename T,
|
||||
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
index_t nrepeat)
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
using namespace ck;
|
||||
@@ -22,7 +21,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
index_t nrepeat)
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
//#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp"
|
||||
|
||||
template <class T,
|
||||
@@ -24,7 +23,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(InDesc,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
index_t nrepeat)
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
|
||||
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
LowerPads,
|
||||
UpperPads,
|
||||
index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
|
||||
|
||||
// reorder weight
|
||||
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
|
||||
|
||||
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
|
||||
|
||||
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
|
||||
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// reorder input
|
||||
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
|
||||
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
|
||||
|
||||
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
|
||||
|
||||
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
|
||||
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// output
|
||||
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
|
||||
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
|
||||
|
||||
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
|
||||
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
|
||||
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
|
||||
|
||||
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
|
||||
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 0
|
||||
constexpr index_t NPerBlock = 1;
|
||||
constexpr index_t KPerBlock = 1;
|
||||
constexpr index_t CPerBlock = 1;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 1;
|
||||
constexpr index_t KPerThread = 1;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
|
||||
constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
|
||||
|
||||
constexpr index_t BlockSize = 8;
|
||||
#elif 1
|
||||
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 16,256,128
|
||||
constexpr index_t NPerBlock = 8;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 5x5, 36x36
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 7x7, 38x38
|
||||
constexpr index_t NPerBlock = 8;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 56x56
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// 3x3 56x56, NKC = 16,256,128, with padding
|
||||
// 3x3 28x28, NKC = 16,512,256, with padding
|
||||
// 3x3 20x84, NKC = 16,256,256, with padding
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
|
||||
constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 5x5 filter, 20x84 image, 1x1 padding
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 1;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 5x5 filter, 28x28 image, 2x2 padding
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 1x1, 28x28
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 2;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
|
||||
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
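// Worked example (illustrative values, not from the original source): for N = 16, K = 256,
// Ho = Wo = 28 with the selected NPerBlock/KPerBlock/HoPerBlock/WoPerBlock = 16/64/2/4,
// GridSize = 1 * 4 * 14 * 7 = 392 workgroups.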
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_chwn_desc),
|
||||
decltype(wei_cyxk_desc),
|
||||
decltype(out_khwn_desc),
|
||||
LowerPads,
|
||||
UpperPads,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
|
||||
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
|
||||
|
||||
// reorder output
|
||||
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
|
||||
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "device.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
|
||||
//#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
|
||||
@@ -448,7 +448,7 @@ int main(int argc, char* argv[])
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
@@ -490,7 +490,7 @@ int main(int argc, char* argv[])
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||