mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00

adding implicit gemm

src/include/ConstantMatrixDescriptor.cuh (new file, 46 lines)
@@ -0,0 +1,46 @@
#pragma once
#include "common.cuh"

template <unsigned NRow, unsigned NCol, unsigned RowStride>
struct ConstantMatrixDescriptor
{
    __host__ __device__ ConstantMatrixDescriptor()
    {
        static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
    }

    __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }

    __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }

    __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }

    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }

    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }

    __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
        return irow * RowStride + icol;
    }

    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                               Number<SubNCol>) const
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
    }
};

template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
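A minimal usage sketch (not part of the commit) of how these compile-time matrix descriptors compose; the Number<> tag comes from common.cuh, and the sizes below are made up for illustration. As noted later in this diff, "constexpr auto" does not compile for these descriptors, so const values are used:

// Illustrative only: an 8x16 matrix padded to a row stride of 20 elements.
__host__ __device__ void constant_matrix_descriptor_example()
{
    const auto desc = make_ConstantMatrixDescriptor(Number<8>{}, Number<16>{}, Number<20>{});

    // 8 * 16 logical elements, 8 * 20 allocated elements including row padding.
    const unsigned size  = desc.GetElementSize();
    const unsigned space = desc.GetElementSpace();

    // Element (irow, icol) lives at irow * RowStride + icol within the padded storage.
    const unsigned offset = desc.Get1dIndex(2, 5); // 2 * 20 + 5 == 45

    // A sub-matrix view keeps the parent's row stride, so offsets computed through it
    // remain valid inside the parent's storage.
    const auto sub_desc = desc.MakeSubMatrixDescriptor(Number<4>{}, Number<8>{});

    (void)size;
    (void)space;
    (void)offset;
    (void)sub_desc;
}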
@@ -1,70 +1,6 @@
#pragma once
#include "common.cuh"

template <class T, T N>
struct Constant
{
    static const T mValue = N;
};

template <unsigned N>
using Number = Constant<unsigned, N>;

template <unsigned... Is>
struct Sequence
{
    static constexpr unsigned nDim = sizeof...(Is);

    const unsigned mData[nDim] = {Is...};

    template <unsigned I>
    __host__ __device__ constexpr unsigned Get(Number<I>) const
    {
        return mData[I];
    }

    template <unsigned I0, unsigned I1>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});

        return Sequence<IR0, IR1>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});

        return Sequence<IR0, IR1, IR2>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});
        constexpr unsigned IR3 = Get(Number<I3>{});

        return Sequence<IR0, IR1, IR2, IR3>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
    __host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});
        constexpr unsigned IR3 = Get(Number<I3>{});

        return Sequence<IR0, IR1, IR2, IR3>{};
    }
};

template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "threadwise_tensor_op.cuh"
#include "threadwise_direct_convolution.cuh"

@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"

template <unsigned BlockSize, class Float, class DstDesc, class F>
__device__ void
@@ -12,4 +12,72 @@ struct is_same<T, T>
    static const bool value = true;
};

__device__ unsigned get_thread_local_id() { return threadIdx.x; }
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }

__device__ unsigned get_block_1d_id() { return blockIdx.x; }

template <class T, T N>
struct Constant
{
    static const T mValue = N;

    __host__ __device__ constexpr T Get() const { return mValue; }
};

template <unsigned N>
using Number = Constant<unsigned, N>;

template <unsigned... Is>
struct Sequence
{
    static constexpr unsigned nDim = sizeof...(Is);

    const unsigned mData[nDim] = {Is...};

    template <unsigned I>
    __host__ __device__ constexpr unsigned Get(Number<I>) const
    {
        return mData[I];
    }

    template <unsigned I0, unsigned I1>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});

        return Sequence<IR0, IR1>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});

        return Sequence<IR0, IR1, IR2>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
    __host__ __device__ constexpr auto Reorder(Number<I0>, Number<I1>, Number<I2>, Number<I3>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});
        constexpr unsigned IR3 = Get(Number<I3>{});

        return Sequence<IR0, IR1, IR2, IR3>{};
    }

    template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
    __host__ __device__ constexpr auto Reorder(Sequence<I0, I1, I2, I3>) const
    {
        constexpr unsigned IR0 = Get(Number<I0>{});
        constexpr unsigned IR1 = Get(Number<I1>{});
        constexpr unsigned IR2 = Get(Number<I2>{});
        constexpr unsigned IR3 = Get(Number<I3>{});

        return Sequence<IR0, IR1, IR2, IR3>{};
    }
};

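A minimal sketch (not part of the commit) of how Number<> and Sequence<> serve as compile-time tags; the permutation below is the same NCHW-to-CHWN reordering the implicit-GEMM kernel uses later in this diff:

// Illustrative only: Sequence<3, 0, 1, 2> maps NCHW dimension ids to CHWN order.
__host__ __device__ void sequence_example()
{
    constexpr auto nchw2chwn = Sequence<3, 0, 1, 2>{};

    // Elements are read through Number<> tags rather than runtime indices.
    const unsigned dim0 = nchw2chwn.Get(Number<0>{}); // == 3
    const unsigned dim1 = nchw2chwn.Get(Number<1>{}); // == 0

    // Reorder() builds a new Sequence type by picking elements by position, e.g.
    // nchw2chwn.Reorder(Number<1>{}, Number<2>{}, Number<3>{}, Number<0>{})
    // would yield Sequence<0, 1, 2, 3>.
    (void)dim0;
    (void)dim1;
}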
@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"

// this is ugly, only for 4d
template <class InDesc, class WeiDesc>
@@ -1,12 +1,13 @@
#pragma once

template <class ThreadMatrixA,
          bool TransA,
          class FloatA,
          class ThreadMatrixB,
          bool TransB,
          class FloatB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          class FloatA,
          class FloatB,
          class FloatC,
          class Accumulator>
__device__ void threadwise_gemm(ThreadMatrixA,
@@ -26,41 +27,51 @@ __device__ void threadwise_gemm(ThreadMatrixA,
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          unsigned BatchSize,
          bool TransC,
          unsigned BlockMatrixStrideA,
          unsigned BlockMatrixStrideB,
          unsigned ThreadMatrixStrideC,
          unsigned BatchSize,
          unsigned BatchPerThread,
          unsigned MPerThread,
          unsigned NPerThread,
          unsigned KPerThread,
          unsigned KPerLoop,
          class Accumulator>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        unsigned batch_begin;
        unsigned block_row_begin;
        unsigned block_col_begin;
        unsigned row_begin;
        unsigned col_begin;
    };

    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
    {
        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};
#if 0
        constexpr auto a_block_desc = BlockMatrixA{};
        constexpr auto b_block_desc = BlockMatrixB{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};
        constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
        constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
        constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
        constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;

        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();
        constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col>{};
        constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col>{};
        constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread>{};

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();
        constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();

        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
@@ -72,12 +83,17 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c

        const auto mtx_c_idnex = CalculateThreadMatrixCIndex(get_thread_local_id());

        mMyThreadOffsetA = xxx;
        mMyThreadoffSetB = xxx;
        // mMyThreadOffsetA = xxx;
        // mMyThreadoffSetB = xxx;
#else
        mMyThreadOffsetA = 0;
        mMyThreadOffsetB = 0;
#endif
    }

    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
#if 0
        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};
        constexpr auto c_block = BlockMatrixC{};
@@ -104,6 +120,9 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c

        return MatrixIndex{
            batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
#else
        return MatrixIndex{0, 0, 0};
#endif
    }

    template <class FloatA, class FloatB, class FloatC>
@@ -111,8 +130,4 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
    {
        // do something
    }

    private:
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;
}
};

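Because CalculateThreadMatrixCIndex is stubbed out above (the real mapping sits under "#if 0" and the active path returns {0, 0, 0}), here is a hedged sketch, with hypothetical names, of the kind of thread-id decomposition the commented-out code suggests: the block's C tile is covered by ceil(m_block / m_thread) x ceil(n_block / n_thread) per-thread tiles, batch-major:

// Illustrative sketch only, not the commit's implementation. All names here
// (ThreadTileIndexSketch, num_thread_tiles_m/n, ...) are hypothetical.
struct ThreadTileIndexSketch
{
    unsigned batch_begin;
    unsigned row_begin;
    unsigned col_begin;
};

__device__ ThreadTileIndexSketch
calculate_thread_c_tile_index_sketch(unsigned thread_id,
                                     unsigned num_thread_tiles_m, // ceil(m_block / m_thread)
                                     unsigned num_thread_tiles_n, // ceil(n_block / n_thread)
                                     unsigned m_thread,
                                     unsigned n_thread)
{
    const unsigned tiles_per_batch = num_thread_tiles_m * num_thread_tiles_n;

    const unsigned batch_id = thread_id / tiles_per_batch;
    const unsigned itmp     = thread_id - batch_id * tiles_per_batch;

    const unsigned tile_row_id = itmp / num_thread_tiles_n;
    const unsigned tile_col_id = itmp - tile_row_id * num_thread_tiles_n;

    // Each thread owns an m_thread x n_thread sub-tile of the block's C tile.
    return ThreadTileIndexSketch{batch_id, tile_row_id * m_thread, tile_col_id * n_thread};
}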
@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "blockwise_direct_convolution.cuh"

@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "blockwise_direct_convolution.cuh"
#include "threadwise_tensor_op.cuh"

@@ -1,7 +1,10 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "common.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "ConstantMatrixDescriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "threadwise_tensor_op.cuh"
#include "gemm.cuh"

template <unsigned GridSize,
          unsigned BlockSize,
@@ -45,19 +48,78 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
    constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

    // tensor view of blockwise input and weight in LDS
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
    // divide block work: NCHW
    constexpr unsigned NBlockWork =
        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr unsigned KBlockWork =
        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr unsigned HBlockWork =
        (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr unsigned WBlockWork =
        (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    unsigned itmp = get_block_1d_id();
    const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    const unsigned h_block_work_id = itmp / WBlockWork;
    const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;

    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
    const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;

    const unsigned hi_block_data_begin = ho_block_data_begin;
    const unsigned wi_block_data_begin = wo_block_data_begin;

    // tensor view of blockwise input and weight in LDS
    constexpr auto wei_srck_block_desc =
        make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});

    // matrix view of blockwise input and weight in LDS
    constexpr auto in_cxhwn_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>, Number<HiPerBlock * WiPerBlock * NPerBlock>);
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});

    constexpr auto wei_srcxk_block_mtx_desc =
        make_ConstantMatrixDescriptor(Number<S * R * CPerBlock>, Number<KPerBlock>);
    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    // A_matrix and B_matrix saved in LDS, C_matrix saved in register
    // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
    // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
    const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile

    const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{},
        Number<WoPerBlock * NPerBlock>{},
        Number<in_chwn_block_desc.GetStride(I1)>{}); // constexpr doesn't compile

    const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile

    auto f_accum = [](auto& c, auto& ab) { c += ab; };

    const auto blockwise_batch_gemm =
        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
                                                                   decltype(a_cxk_block_mtx_desc),
                                                                   decltype(b_cxwn_block_mtx_desc),
                                                                   decltype(c_kxwn_thread_mtx_desc),
                                                                   true,
                                                                   false,
                                                                   false,
                                                                   0,
                                                                   in_chwn_block_desc.GetStride(I1),
                                                                   out_hkwn_thread_desc.GetStride(
                                                                       I1),
                                                                   HoPerBlock,
                                                                   HoPerThread,
                                                                   CPerThread,
                                                                   decltype(f_accum)>{};

    // LDS
    constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
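As a worked example of the block-work decomposition above (illustrative helper, not part of the commit): with KBlockWork = 4, HBlockWork = 3 and WBlockWork = 5, block id 37 = ((0 * 4 + 2) * 3 + 1) * 5 + 2 decomposes into n = 0, k = 2, h = 1, w = 2:

// Illustrative only: the same mixed-radix decomposition of a 1d block id into
// (n, k, h, w) block-work ids, written as a standalone helper. NBlockWork only
// bounds n via the grid size and is not needed for the decomposition itself.
struct BlockWorkId
{
    unsigned n, k, h, w;
};

__device__ BlockWorkId decompose_block_id(unsigned block_id,
                                          unsigned KBlockWork,
                                          unsigned HBlockWork,
                                          unsigned WBlockWork)
{
    unsigned itmp = block_id;

    const unsigned n = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n * (KBlockWork * HBlockWork * WBlockWork);

    const unsigned k = itmp / (HBlockWork * WBlockWork);
    itmp -= k * (HBlockWork * WBlockWork);

    const unsigned h = itmp / WBlockWork;
    const unsigned w = itmp - h * WBlockWork;

    return BlockWorkId{n, k, h, w};
}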
@@ -66,87 +128,59 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
    __shared__ Float p_in_block[in_block_size];
    __shared__ Float p_wei_block[wei_block_size];

    // a series of batched GEMM
    // blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
    // A_matrix and B_matrix saved in LDS, c_matrix saved in register
    // A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
    // B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
    // C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
    constexpr auto a_block_mtx_desc =
        wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(Number<CPerBlock>{}, Number<KPerBlock>{});

    constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
        Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});

    auto f_accum = (auto& c, auto& v) { c += v; };

    const auto blockwise_batch_gemm =
        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
                                                                   a_block_mtx_desc,
                                                                   b_block_mtx_desc,
                                                                   true,
                                                                   false,
                                                                   HoPerBlock,
                                                                   0,
                                                                   xxx_b_matrix_stride,
                                                                   HoPerThread,
                                                                   KPerThread,
                                                                   NPerThread * WoPerThread,
                                                                   CPerTrhead,
                                                                   decltype(f_accum)>{};

    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // input: global mem to LDS,
        // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
        constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};

        blockwise_4d_tensor_copy_reorder<BlockSize>(in_nchw_global_desc,
                                                    p_in_global,
                                                    in_chwn_block_desc,
                                                    p_in_block,
                                                    in_chwn_block_desc,
                                                    reorder_nchw2chwn);
        blockwise_4d_tensor_copy_reorder<BlockSize>(
            in_nchw_global_desc,
            p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
                                                         c_block_data_begin,
                                                         hi_block_data_begin,
                                                         wi_block_data_begin),
            in_chwn_block_desc,
            p_in_block,
            in_chwn_block_desc,
            reorder_nchw2chwn);

        // weight: global mem to LDS,
        // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
        constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};

        blockwise_4d_tensor_copy_reorder<BlockSize>(wei_csrk_global_desc,
                                                    p_wei_global,
                                                    wei_csrk_block_desc,
                                                    p_wei_block,
                                                    wei_csrk_block_desc,
                                                    reorder_kcsr2csrk);
        blockwise_4d_tensor_copy_reorder<BlockSize>(
            wei_kcsr_global_desc,
            p_wei_global +
                wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
            wei_srck_block_desc,
            p_wei_block,
            wei_srck_block_desc,
            reorder_kcsr2srck);

        __syncthreads();

        // loop over filter point
        // a series of batched GEMM
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
                blockwise_batch_gemm.run(
                    p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex(xxxxx, xxxx),
                    p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(xxxx, xxxx),
                    p_out_thread);
                blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                         p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
                                         p_out_thread);
            }
        }
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id());
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());

    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
    const unsigned k_thread_data_begin = matrix_c_index.col_begin;
@@ -160,10 +194,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,

        out_hkwn_thread_desc,
        p_out_thread,
        out_nkhw_global_desc,
        p_out_global + out_nkhw_global_desc.GetIndex(n_block_data_begin,
                                                     k_block_data_begin + k_thread_data_begin,
                                                     ho_block_data_begin + ho_thread_data_begin,
                                                     wo_block_data_begin + wo_thread_data_begin),
        p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
                                                       k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin),
        out_hkwn_thread_desc,
        reorder_hkwn2nkhw);
}
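For reference, the math each blockwise batched GEMM call above performs per filter point (s, r) is C += transpose(A) * B, with A[C, K] taken from the weight tile and B[C, Wo*N] from the input tile in LDS. A naive, illustrative loop version (not the commit's blockwise implementation; the leading-dimension parameters are assumptions) looks like:

// Illustrative reference only: c[m, n] += sum_k a[k, m] * b[k, n], i.e. C += A^T * B,
// with row-major storage and explicit leading dimensions.
template <class Float>
__device__ void reference_gemm_tn(unsigned K_dim, // contraction dimension (C channels here)
                                  unsigned M_dim, // rows of C (output channels K here)
                                  unsigned N_dim, // cols of C (Wo * N here)
                                  const Float* a, unsigned lda, // a[k * lda + m]
                                  const Float* b, unsigned ldb, // b[k * ldb + n]
                                  Float* c, unsigned ldc)       // c[m * ldc + n]
{
    for(unsigned m = 0; m < M_dim; ++m)
        for(unsigned n = 0; n < N_dim; ++n)
            for(unsigned k = 0; k < K_dim; ++k)
                c[m * ldc + n] += a[k * lda + m] * b[k * ldb + n];
}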
@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"
#include "blockwise_winograd_transform.cuh"
#include "threadwise_winograd_transform.cuh"

@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"

// optimized for scenario if p_in, p_wei, p_out are in register
template <class Float, class InDesc, class WeiDesc, class OutDesc>
@@ -1,5 +1,5 @@
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "ConstantTensorDescriptor.cuh"

template <class Float, class Desc, class F>
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)