mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
initial padding support for nchw
This commit is contained in:
@@ -51,6 +51,7 @@ template <index_t GridSize,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
{
|
||||
#if 1
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
@@ -69,20 +70,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_ho_wo_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
@@ -126,30 +131,35 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
// global memory
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// batch descritpor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
@@ -162,8 +172,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
@@ -180,11 +190,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
@@ -192,7 +205,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already have blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
@@ -215,8 +228,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
in_e_n1_b_n2_block_desc.GetLength(I0),
|
||||
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
|
||||
in_e_n1_b_n2_block_desc.GetLength(I3),
|
||||
in_e_n1_b_n2_block_desc.GetStride(I0));
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
@@ -288,21 +304,28 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
static_assert(K % (K1 * K2) == 0, "wrong!");
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
|
||||
make_native_tensor_descriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
|
||||
reorder_tensor_descriptor_given_upper2lower(out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc,
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -317,32 +340,159 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_n2_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
out_k_n1_b_n2_global_desc.CalculateOffset(
|
||||
{k_thread_data_on_global, 0, b_thread_data_on_global, 0});
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
7,
|
||||
7,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
}
|
||||
}
|
||||
#else
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegent way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
|
||||
constexpr auto out_n_k_h_w_global_desc =
|
||||
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
|
||||
"wrong! global vector load of input tensor is wrong");
|
||||
|
||||
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
|
||||
"wrong! aligment requirement for vectorized global load of input tensor will "
|
||||
"be violated");
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);
|
||||
print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
|
||||
print_tensor_descriptor("in_n0_n1_n2_c_y_ho_x_wo_global_desc: ",
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc);
|
||||
print_tensor_descriptor("in_e_n1_b_n2_global_desc: ", in_e_n1_b_n2_global_desc);
|
||||
|
||||
auto coord3 = make_tensor_coordinate_v2(in_e_n1_b_n2_global_desc, {1, 1, 1, 1});
|
||||
|
||||
auto idx3 = coord3.GetIndex();
|
||||
auto idx2 = coord3.GetLowerCoordinate().GetIndex();
|
||||
auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
auto idx0 =
|
||||
coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();
|
||||
|
||||
print_array("idx3: ", idx3);
|
||||
print_array("idx2: ", idx2);
|
||||
print_array("idx1: ", idx1);
|
||||
print_array("idx0: ", idx0);
|
||||
}
|
||||
#endif
|
||||
p_out_global[0] = in_e_n1_b_n2_global_desc.CalculateOffset({0, 0, 10, 0});
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -52,7 +53,7 @@ __host__ __device__ constexpr auto
|
||||
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
using TDesc = ConstantTensorDescriptor<Ts...>;
|
||||
@@ -63,7 +64,18 @@ __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorD
|
||||
TDesc::GetStrides()[0]>{};
|
||||
}
|
||||
|
||||
template <class TDesc>
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
using TDesc = NativeTensorDescriptor<Ts...>;
|
||||
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
|
||||
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
|
||||
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
|
||||
TDesc::GetLengths()[1],
|
||||
TDesc::GetStrides()[0]>{};
|
||||
}
|
||||
|
||||
template <typename TDesc>
|
||||
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
|
||||
{
|
||||
printf(
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
namespace ck {
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed_old(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
@@ -14,12 +14,12 @@ __host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number<Align>)
|
||||
{
|
||||
constexpr index_t L_back_align =
|
||||
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
|
||||
|
||||
return calculate_tensor_strides_packed(
|
||||
return calculate_tensor_strides_packed_old(
|
||||
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed_old(GetLengths()));
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
@@ -468,7 +468,7 @@ struct ConstantTensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr auto Pack()
|
||||
{
|
||||
using packed_strides = decltype(calculate_tensor_strides_packed(Lengths{}));
|
||||
using packed_strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
|
||||
return ConstantTensorDescriptor<Lengths, packed_strides>{};
|
||||
}
|
||||
|
||||
@@ -490,7 +490,7 @@ struct ConstantTensorDescriptor
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
|
||||
using Strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
@@ -503,7 +503,7 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
|
||||
using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@ struct PassThrough
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
@@ -51,11 +49,11 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <typename LowLengths, typename LeftPads, typename RightPads>
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths, typename LeftPads, typename RightPads>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::Size();
|
||||
static constexpr index_t nDim = LowerLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDim>;
|
||||
using UpperIndex = MultiIndex<nDim>;
|
||||
@@ -64,11 +62,9 @@ struct Pad
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return GetLowerLengths() + LeftPads{} + RightPads{};
|
||||
return LowerLengths{} + LeftPads{} + RightPads{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
@@ -98,7 +94,7 @@ struct Pad
|
||||
|
||||
// only check if there is right-padding
|
||||
static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
|
||||
flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowLengths::At(idim);
|
||||
flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowerLengths::At(idim);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -106,11 +102,11 @@ struct Pad
|
||||
}
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <typename LowLengths>
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths>
|
||||
struct Merge
|
||||
{
|
||||
static constexpr index_t nDimLow = LowLengths::Size();
|
||||
static constexpr index_t nDimLow = LowerLengths::Size();
|
||||
static constexpr index_t nDimUp = 1;
|
||||
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
@@ -120,12 +116,10 @@ struct Merge
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return Sequence<accumulate_on_sequence(
|
||||
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
// emulate constexpr lambda
|
||||
@@ -158,11 +152,11 @@ struct Merge
|
||||
|
||||
constexpr auto pseudo_low_strides =
|
||||
reverse_inclusive_scan_sequence(
|
||||
GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
#if 1 // would compile to same ISA?
|
||||
#if 1 // would these 2 versions be compiled to same ISA?
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDimLow - 1, 1>{}(
|
||||
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
|
||||
|
||||
@@ -176,16 +170,75 @@ struct Merge
|
||||
}
|
||||
|
||||
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
|
||||
// If idx_up_diff is known at compile-time, many calculations can be optimized
|
||||
// away by compiler
|
||||
// This function assume idx_low_old is not out-of-bound
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
|
||||
const UpperIndex& /* idx_up_old */,
|
||||
const LowerIndex& idx_low_old)
|
||||
{
|
||||
LowerIndex idx_low_diff;
|
||||
// do nothing if idx_up_diff == 0
|
||||
if(idx_up_diff[0] == 0)
|
||||
{
|
||||
return make_zero_array<index_t, nDimLow>();
|
||||
}
|
||||
|
||||
// not implemeneted
|
||||
// CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
|
||||
// If idx_up_diff is known at compile-time, the calculation can
|
||||
// be done at compile-time. However, if idx_up_diff is only known
|
||||
// at run-time, then the calculation will also be computed at
|
||||
// run-time, and can be very expensive.
|
||||
LowerIndex idx_low_new = idx_low_old + CalculateLowerIndex(idx_up_diff);
|
||||
|
||||
return idx_low_diff;
|
||||
if(idx_up_diff[0] > 0)
|
||||
{
|
||||
bool carry = false;
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++idx_low_new(i);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(idx_low_new[i] >= LowerLengths::At(i))
|
||||
{
|
||||
idx_low_new(i) -= LowerLengths::At(i);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
else if(idx_up_diff[0] < 0)
|
||||
{
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDimLow, 1>{}([&](auto ireverse) {
|
||||
constexpr index_t i = nDimLow - 1 - ireverse;
|
||||
|
||||
if(borrow)
|
||||
{
|
||||
--idx_low_new(i);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(idx_low_new[i] < 0)
|
||||
{
|
||||
idx_low_new(i) += LowerLengths::At(i);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return idx_low_new - idx_low_old;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
@@ -198,12 +251,12 @@ struct Merge
|
||||
}
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <typename UpLengths>
|
||||
// UpperLengths: Sequence<...>
|
||||
template <typename UpperLengths>
|
||||
struct Unmerge
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::Size();
|
||||
static constexpr index_t nDimUp = UpperLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
@@ -212,23 +265,16 @@ struct Unmerge
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths()
|
||||
{
|
||||
constexpr index_t low_length =
|
||||
accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
return Sequence<low_length>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low{0};
|
||||
|
||||
constexpr auto pseudo_up_strides =
|
||||
typename sequence_reverse_inclusive_scan<UpLengths, math::multiplies<index_t>, 1>::
|
||||
type{};
|
||||
reverse_inclusive_scan_sequence(
|
||||
UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
|
||||
@@ -245,47 +291,45 @@ struct Unmerge
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
// TODO: should this function be here? should it be specific for padding check?
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
// UpperLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
template <index_t LowLength, typename UpLengths, typename Coefficients>
|
||||
template <typename UpperLengths, typename Coefficients>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::Size();
|
||||
static constexpr index_t nDimUp = UpperLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<nDimLow>;
|
||||
using UpperIndex = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ explicit constexpr Embed()
|
||||
{
|
||||
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
|
||||
static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
constexpr index_t low_id_max =
|
||||
Coefficients::Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
|
||||
math::plus<index_t>{},
|
||||
Number<0>{});
|
||||
|
||||
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
|
||||
{
|
||||
LowerIndex idx_low(Coefficients{}[nDimUp]);
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low[0] += idx_up[idim] * Coefficients{}[idim]; });
|
||||
[&](auto idim) { idx_low(0) += idx_up[idim] * Coefficients{}[idim]; });
|
||||
|
||||
return idx_low;
|
||||
}
|
||||
@@ -298,12 +342,18 @@ struct Embed
|
||||
LowerIndex idx_low_diff{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { idx_low_diff[0] += idx_up_diff[idim] * Coefficients{}[idim]; });
|
||||
[&](auto idim) { idx_low_diff(0) += idx_up_diff[idim] * Coefficients{}[idim]; });
|
||||
|
||||
return idx_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -207,10 +207,12 @@ struct TransformedTensorDescriptor
|
||||
return LowTensorDescriptor{};
|
||||
}
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetLowerLengths()
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetLengths();
|
||||
}
|
||||
#endif
|
||||
|
||||
struct lambda_GetUpperLengths
|
||||
{
|
||||
@@ -383,35 +385,5 @@ struct TransformedTensorDescriptor
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
|
||||
Sequence<Strides...>)
|
||||
{
|
||||
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
|
||||
{
|
||||
constexpr auto strides = reverse_inclusive_scan_sequence(
|
||||
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
|
||||
return make_native_tensor_descriptor(Lengths{}, strides);
|
||||
}
|
||||
|
||||
template <typename LowTensorDescriptor,
|
||||
typename Transforms,
|
||||
typename LowDimensionIds,
|
||||
typename UpDimensionIds>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowTensorDescriptor,
|
||||
Transforms,
|
||||
LowDimensionIds,
|
||||
UpDimensionIds>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,96 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
template <typename Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
constexpr index_t L_back_align =
|
||||
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
|
||||
|
||||
return calculate_tensor_strides_packed(
|
||||
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
|
||||
Sequence<Strides...>)
|
||||
{
|
||||
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
|
||||
{
|
||||
constexpr auto strides = calculate_tensor_strides_packed(Lengths{});
|
||||
|
||||
return make_native_tensor_descriptor(Lengths{}, strides);
|
||||
}
|
||||
|
||||
template <typename Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
|
||||
return make_native_tensor_descriptor(Lengths{}, strides);
|
||||
}
|
||||
|
||||
template <typename LowTensorDescriptor,
|
||||
typename Transforms,
|
||||
typename LowDimensionIds,
|
||||
typename UpDimensionIds>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowTensorDescriptor,
|
||||
Transforms,
|
||||
LowDimensionIds,
|
||||
UpDimensionIds>{};
|
||||
}
|
||||
|
||||
template <typename LowerTensorDescriptor,
|
||||
index_t... LowerLengths,
|
||||
index_t... LowerDimensionIds,
|
||||
index_t... UpperDimensionIds>
|
||||
__host__ __device__ constexpr auto reorder_tensor_descriptor_impl(LowerTensorDescriptor,
|
||||
Sequence<LowerLengths...>,
|
||||
Sequence<LowerDimensionIds...>,
|
||||
Sequence<UpperDimensionIds...>)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowerTensorDescriptor,
|
||||
Tuple<PassThrough<LowerLengths>...>,
|
||||
Tuple<Sequence<LowerDimensionIds>...>,
|
||||
Tuple<Sequence<UpperDimensionIds>...>>{};
|
||||
}
|
||||
|
||||
template <typename LowerTensorDescriptor, typename MapLower2Upper>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_lower2upper(LowerTensorDescriptor, MapLower2Upper)
|
||||
{
|
||||
static_assert(is_valid_sequence_map<MapLower2Upper>{},
|
||||
"wrong! MapLower2Upper is not a valid map");
|
||||
|
||||
return reorder_tensor_descriptor_impl(
|
||||
LowerTensorDescriptor{},
|
||||
LowerTensorDescriptor::GetLengths(),
|
||||
typename arithmetic_sequence_gen<0, LowerTensorDescriptor::GetNumOfDimension(), 1>::type{},
|
||||
MapLower2Upper{});
|
||||
}
|
||||
|
||||
template <typename LowerTensorDescriptor, typename MapUpper2Lower>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
|
||||
{
|
||||
return reorder_tensor_descriptor_given_lower2upper(
|
||||
LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
|
||||
}
|
||||
|
||||
template <typename... NativeDimensions>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
|
||||
|
||||
@@ -951,10 +951,10 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
|
||||
// The dimension access order should be the same on src and dst.
|
||||
// It is designed for cases, where one of src and dst is register, and
|
||||
// the other is device memory or LDS
|
||||
template <class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class DimAccessOrder,
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
|
||||
@@ -91,8 +91,8 @@ int main(int argc, char* argv[])
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t HI = 32;
|
||||
constexpr index_t WI = 32;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -100,8 +100,8 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
|
||||
|
||||
Reference in New Issue
Block a user