mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
bug fix and tune implicit gemm
This commit is contained in:
@@ -105,13 +105,24 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
|
||||
|
||||
const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = c_thread_mtx_index.batch_begin * a_block_mtx.GetElementSpace() +
|
||||
mMyThreadOffsetA = c_thread_mtx_index.batch_begin * BlockMatrixStrideA +
|
||||
((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
|
||||
: a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
|
||||
|
||||
mMyThreadOffsetB = c_thread_mtx_index.batch_begin * b_block_mtx.GetElementSpace() +
|
||||
mMyThreadOffsetB = c_thread_mtx_index.batch_begin * BlockMatrixStrideB +
|
||||
((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
|
||||
|
||||
#if 0
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch_begin,
|
||||
c_thread_mtx_index.row_begin,
|
||||
c_thread_mtx_index.col_begin,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
|
||||
|
||||
@@ -174,7 +174,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
#if 0
|
||||
threadwise_direct_convolution_2(
|
||||
in_thread_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
|
||||
@@ -84,6 +84,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
|
||||
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
|
||||
@@ -184,7 +185,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
in_nchw_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// weight: global mem to LDS,
|
||||
// convert [K,C,S,R] to [S,R,C,K]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -209,6 +210,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
@@ -222,16 +224,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
#if 0
|
||||
printf("%u %u, %u %u %u\n",get_block_1d_id(), get_thread_local_1d_id(), matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
|
||||
#endif
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
|
||||
|
||||
@@ -151,6 +151,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
in_w_new_read>{});
|
||||
|
||||
#if 0
|
||||
// this verison reused old input data in register, and read new data from LDS
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
{
|
||||
@@ -200,6 +201,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
// this version read all input from LDS when filter moves
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
{
|
||||
@@ -226,4 +228,4 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user