mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
adding implicit gemm
This commit is contained in:
@@ -35,9 +35,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = Constant<bool, true>;
|
||||
constexpr auto False = Constant<bool, false>;
|
||||
|
||||
constexpr auto in_nchw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcsr_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
@@ -48,13 +45,20 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
|
||||
|
||||
// block
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
|
||||
constexpr auto wei_srck_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});
|
||||
|
||||
// matrix view of blockwise input and weight in LDS
|
||||
constexpr auto in_cxhwn_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>, Number<HiPerBlock * WiPerBlock * NPerBlock>);
|
||||
|
||||
constexpr auto wei_srcxk_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<S * R * CPerBlock>, Number<KPerBlock>);
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();
|
||||
@@ -62,8 +66,38 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
__shared__ Float p_in_block[in_block_size];
|
||||
__shared__ Float p_wei_block[wei_block_size];
|
||||
|
||||
// thread
|
||||
constexpr auto out_hkwn_thread_desc = xxxxxx();
|
||||
// a series of batched GEMM
|
||||
// blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, c_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
|
||||
constexpr auto a_block_mtx_desc =
|
||||
wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(Number<CPerBlock>{}, Number<KPerBlock>{});
|
||||
|
||||
constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});
|
||||
|
||||
auto f_accum = (auto& c, auto& v) { c += v; };
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
a_block_mtx_desc,
|
||||
b_block_mtx_desc,
|
||||
true,
|
||||
false,
|
||||
HoPerBlock,
|
||||
0,
|
||||
xxx_b_matrix_stride,
|
||||
HoPerThread,
|
||||
KPerThread,
|
||||
NPerThread * WoPerThread,
|
||||
CPerTrhead,
|
||||
decltype(f_accum)>{};
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
@@ -85,14 +119,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
in_chwn_block_desc,
|
||||
reorder_nchw2chwn);
|
||||
|
||||
// matrix view of input
|
||||
constexpr unsigned in_row = in_chwn_block_desc.GetLength(I0);
|
||||
constexpr unsigned in_col = in_chwn_block_desc.GetLength(I1) *
|
||||
in_chwn_block_desc.GetLength(I2) *
|
||||
in_chwn_block_desc.GetLength(I3);
|
||||
constexpr auto in_cxhwn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<in_row>, Number<in_col>, Number<in_col>);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
// convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K]
|
||||
constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};
|
||||
@@ -104,44 +130,8 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
wei_csrk_block_desc,
|
||||
reorder_kcsr2csrk);
|
||||
|
||||
// matrix view of wei
|
||||
constexpr unsigned wei_row = wei_srck_block_desc.GetLength(I0) *
|
||||
wei_srck_block_desc.GetLength(I1) *
|
||||
wei_srck_block_desc.GetLength(I2);
|
||||
constexpr unsigned wei_col = wei_srck_block_desc.GetLength(I3);
|
||||
constexpr auto wei_srcxk_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<wei_row>, Number<wei_col>, Number<wei_col>);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
// blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, c_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N]
|
||||
constexpr auto a_block_mtx_desc = wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{});
|
||||
|
||||
constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});
|
||||
|
||||
auto f_accum = (auto& c, auto& v) { c += v; };
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
a_block_mtx_desc,
|
||||
b_block_mtx_desc,
|
||||
true,
|
||||
false,
|
||||
HoPerBlock,
|
||||
0,
|
||||
xxx_b_matrix_stride,
|
||||
HoPerThread,
|
||||
KPerThread,
|
||||
NPerThread * WoPerThread,
|
||||
CPerTrhead,
|
||||
decltype(f_accum)>{};
|
||||
// loop over filter point
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
@@ -165,6 +155,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
// output: register to global mem,
|
||||
// convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo]
|
||||
constexpr auto reorder_hkwn2nkhw = Sequence<2, 1, 3, 0>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
|
||||
Reference in New Issue
Block a user