mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
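bug fix: index the input block with loop index s in the batched GEMM (Get1dIndex(0, s, r, 0) instead of Get1dIndex(0, 0, r, 0))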
This commit is contained in:
@@ -94,7 +94,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
|
||||
@@ -156,7 +156,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -169,9 +168,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
wei_kcsr_global_desc,
|
||||
@@ -181,11 +178,9 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
p_wei_block,
|
||||
wei_kcsr_block_desc.GetLengths(),
|
||||
reorder_srck_from_kcsr);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
@@ -194,12 +189,11 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
|
||||
Reference in New Issue
Block a user