mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
adding implicit gemm
This commit is contained in:
@@ -1,6 +1,22 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
|
||||
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
|
||||
{
|
||||
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
|
||||
|
||||
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
@@ -69,24 +85,14 @@ struct ConstantTensorDescriptor
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
{
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
};
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
|
||||
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
|
||||
{
|
||||
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
|
||||
|
||||
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
|
||||
}
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
|
||||
{
|
||||
@@ -124,4 +130,4 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
}
|
||||
@@ -83,31 +83,31 @@ template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class RefDesc,
|
||||
class Reorder,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class F>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
RefDesc,
|
||||
Reorder,
|
||||
F f)
|
||||
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr unsigned IT0 = Reorder{}.Get(I0);
|
||||
constexpr unsigned IT1 = Reorder{}.Get(I1);
|
||||
constexpr unsigned IT2 = Reorder{}.Get(I2);
|
||||
constexpr unsigned IT3 = Reorder{}.Get(I3);
|
||||
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
|
||||
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = RefDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
@@ -133,7 +133,7 @@ blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
|
||||
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]);
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
@@ -164,7 +164,7 @@ blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
|
||||
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]);
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
@@ -183,23 +183,28 @@ template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class RefDesc,
|
||||
class Reorder>
|
||||
__device__ void blockwise_4d_tensor_copy_reorder(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc, Reorder)
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, Reorder{}, f_copy);
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class RefDesc>
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void blockwise_4d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc)
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
constexpr auto reorder = Sequence<0, 1, 2, 3>{};
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
blockwise_4d_tensor_copy_reorder<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, reorder);
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ using Number = Constant<unsigned, N>;
|
||||
template <unsigned... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence<Is...>;
|
||||
|
||||
static constexpr unsigned nDim = sizeof...(Is);
|
||||
|
||||
const unsigned mData[nDim] = {Is...};
|
||||
@@ -40,44 +42,24 @@ struct Sequence
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1>
|
||||
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
constexpr unsigned IR0 = Get(Number<I0>{});
|
||||
constexpr unsigned IR1 = Get(Number<I1>{});
|
||||
constexpr auto old_sequence = Type{};
|
||||
|
||||
return Sequence<IR0, IR1>{};
|
||||
}
|
||||
constexpr unsigned NR0 = old_sequence.mData[I0];
|
||||
constexpr unsigned NR1 = old_sequence.mData[I1];
|
||||
constexpr unsigned NR2 = old_sequence.mData[I2];
|
||||
constexpr unsigned NR3 = old_sequence.mData[I3];
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2>
|
||||
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
|
||||
{
|
||||
constexpr unsigned IR0 = Get(Number<I0>{});
|
||||
constexpr unsigned IR1 = Get(Number<I1>{});
|
||||
constexpr unsigned IR2 = Get(Number<I2>{});
|
||||
|
||||
return Sequence<IR0, IR1, IR2>{};
|
||||
return Sequence<NR0, NR1, NR2, NR3>{};
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
|
||||
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
constexpr unsigned IR0 = Get(Number<I0>{});
|
||||
constexpr unsigned IR1 = Get(Number<I1>{});
|
||||
constexpr unsigned IR2 = Get(Number<I2>{});
|
||||
constexpr unsigned IR3 = Get(Number<I3>{});
|
||||
|
||||
return Sequence<IR0, IR1, IR2, IR3>{};
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
constexpr unsigned IR0 = Get(Number<I0>{});
|
||||
constexpr unsigned IR1 = Get(Number<I1>{});
|
||||
constexpr unsigned IR2 = Get(Number<I2>{});
|
||||
constexpr unsigned IR3 = Get(Number<I3>{});
|
||||
|
||||
return Sequence<IR0, IR1, IR2, IR3>{};
|
||||
// don't know how to implement this
|
||||
printf("Sequence::ReorderByPutOldToNew not implemented");
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -159,7 +159,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
wi_block_data_begin),
|
||||
in_block_desc,
|
||||
p_in_block,
|
||||
in_block_desc);
|
||||
in_block_desc.GetLengths());
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
@@ -167,7 +167,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
wei_block_desc,
|
||||
p_wei_block,
|
||||
wei_block_desc);
|
||||
wei_block_desc.GetLengths());
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -209,5 +209,5 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_thread_desc);
|
||||
out_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
@@ -74,17 +74,39 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
constexpr auto wei_srck_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});
|
||||
// tensor view of un-reorderd blockwise input and weight (imaginary)
|
||||
constexpr auto in_nchw_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
|
||||
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
constexpr auto wei_kcsr_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, S, R>{});
|
||||
|
||||
// tensor view of reordered blockwise input and weight in LDS
|
||||
constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{};
|
||||
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor(
|
||||
in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw));
|
||||
|
||||
constexpr auto reorder_srck_from_kcsr = Sequence<2, 3, 1, 0>{};
|
||||
constexpr auto wei_srck_block_desc = make_ConstantTensorDescriptor(
|
||||
wei_kcsr_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_srck_from_kcsr));
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
|
||||
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
@@ -97,7 +119,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_chwn_block_desc.GetStride(I1)>{}); // constexpr doesn't compile
|
||||
Number<in_chwn_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
@@ -137,11 +159,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
|
||||
constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};
|
||||
|
||||
blockwise_4d_tensor_copy_reorder<BlockSize>(
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
in_nchw_global_desc,
|
||||
p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
@@ -149,21 +170,22 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
wi_block_data_begin),
|
||||
in_chwn_block_desc,
|
||||
p_in_block,
|
||||
in_chwn_block_desc,
|
||||
reorder_nchw2chwn);
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
|
||||
constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};
|
||||
|
||||
blockwise_4d_tensor_copy_reorder<BlockSize>(
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
wei_kcsr_global_desc,
|
||||
p_wei_global +
|
||||
wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc,
|
||||
reorder_kcsr2srck);
|
||||
wei_kcsr_block_desc.GetLengths(),
|
||||
reorder_srck_from_kcsr);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -187,10 +209,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread;
|
||||
|
||||
// output: register to global mem,
|
||||
// convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo]
|
||||
constexpr auto reorder_hkwn2nkhw = Sequence<2, 1, 3, 0>{};
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder(
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
@@ -198,6 +220,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_hkwn_thread_desc,
|
||||
reorder_hkwn2nkhw);
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_nkhw_from_hkwn);
|
||||
}
|
||||
219
src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh
Normal file
219
src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh
Normal file
@@ -0,0 +1,219 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
#include "gemm.cuh"
|
||||
|
||||
template <unsigned GridSize,
|
||||
unsigned BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned NPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
unsigned HoPerBlock,
|
||||
unsigned WoPerBlock,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
__global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
{
|
||||
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
|
||||
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
|
||||
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
|
||||
constexpr unsigned NPerThread = NPerBlock;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_srck_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned S = wei_srck_global_desc.GetLength(I0);
|
||||
constexpr unsigned R = wei_srck_global_desc.GetLength(I1);
|
||||
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
|
||||
|
||||
// divide block work: NCHW
|
||||
constexpr unsigned NBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
constexpr unsigned KBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
|
||||
constexpr unsigned HBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr unsigned WBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
unsigned itmp = get_block_1d_id();
|
||||
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const unsigned h_block_work_id = itmp / WBlockWork;
|
||||
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
|
||||
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// tensor view of un-reorderd blockwise input and weight (imaginary)
|
||||
constexpr auto in_nchw_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
|
||||
|
||||
constexpr auto wei_srck_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});
|
||||
|
||||
// tensor view of reordered blockwise input and weight in LDS
|
||||
constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{};
|
||||
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor(
|
||||
in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw));
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
|
||||
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_chwn_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
auto f_accum = [](auto& c, auto& ab) { c += ab; };
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I1),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
decltype(f_accum)>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_size];
|
||||
__shared__ Float p_wei_block[wei_block_size];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
in_nchw_global_desc,
|
||||
p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
in_chwn_block_desc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_srck_global_desc,
|
||||
p_wei_global +
|
||||
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
for(unsigned r = 0; r < R; ++r)
|
||||
{
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread;
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_nkhw_from_hkwn);
|
||||
}
|
||||
@@ -101,10 +101,10 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc);
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc);
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths());
|
||||
|
||||
// do convolution
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -159,14 +159,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, 0),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc);
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// read first 1x1 weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, 0),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// do first 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -180,7 +180,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// shift old input to the left
|
||||
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
|
||||
@@ -192,7 +192,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
in_reg_desc,
|
||||
p_in_reg +
|
||||
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
|
||||
in_desc_reg_new_read);
|
||||
in_desc_reg_new_read.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -211,11 +211,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc, p_in + in_desc.Get1dIndex(0, 0, s, r), in_reg_desc, p_in_reg, in_reg_desc);
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, r),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
|
||||
@@ -37,29 +37,34 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
|
||||
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class Float, class SrcDesc, class DstDesc, class RefDesc, class Reorder, class F>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
RefDesc,
|
||||
Reorder,
|
||||
F f)
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr unsigned IT0 = Reorder{}.Get(I0);
|
||||
constexpr unsigned IT1 = Reorder{}.Get(I1);
|
||||
constexpr unsigned IT2 = Reorder{}.Get(I2);
|
||||
constexpr unsigned IT3 = Reorder{}.Get(I3);
|
||||
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2);
|
||||
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = RefDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
@@ -74,7 +79,7 @@ threadwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
const unsigned did[4] = {did0, did1, did2, did3};
|
||||
|
||||
const unsigned bindex =
|
||||
dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]);
|
||||
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
@@ -92,29 +97,29 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class RefDesc, class Reorder>
|
||||
__device__ void threadwise_4d_tensor_copy_reorder(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc, Reorder)
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder<Float,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
RefDesc,
|
||||
Reorder,
|
||||
decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, Reorder{}, f_copy);
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class RefDesc>
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc)
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
auto reorder = Sequence<0, 1, 2, 3>{};
|
||||
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder<Float, SrcDesc, DstDesc, RefDesc, decltype(reorder)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, reorder);
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
|
||||
Reference in New Issue
Block a user