mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor
This commit is contained in:
@@ -336,14 +336,6 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
int main()
|
||||
{
|
||||
#if 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned HI = 4;
|
||||
constexpr unsigned WI = 4;
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned HI = 34;
|
||||
@@ -352,13 +344,13 @@ int main()
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 1
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 34;
|
||||
constexpr unsigned WI = 34;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 64;
|
||||
@@ -369,12 +361,12 @@ int main()
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 64;
|
||||
constexpr unsigned HI = 66;
|
||||
constexpr unsigned WI = 66;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 36;
|
||||
constexpr unsigned WI = 36;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned S = 5;
|
||||
constexpr unsigned R = 5;
|
||||
#endif
|
||||
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
|
||||
|
||||
@@ -52,7 +52,7 @@ void device_implicit_gemm_convolution(
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -60,7 +60,7 @@ void device_implicit_gemm_convolution(
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
|
||||
@@ -152,7 +152,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -165,9 +164,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_srck_global_desc,
|
||||
@@ -176,11 +173,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
@@ -194,7 +189,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
|
||||
Reference in New Issue
Block a user