mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
another version of direct conv
This commit is contained in:
@@ -130,6 +130,7 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
blockwise_4d_tensor_copy<TFloat,
|
||||
decltype(in_block_src_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc),
|
||||
BlockSize>(in_block_src_desc,
|
||||
p_in_global +
|
||||
in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
@@ -137,17 +138,20 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
in_block_desc,
|
||||
p_in_block);
|
||||
p_in_block,
|
||||
in_block_desc);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_4d_tensor_copy<TFloat,
|
||||
decltype(wei_block_src_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc),
|
||||
BlockSize>(
|
||||
wei_block_src_desc,
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
|
||||
wei_block_desc,
|
||||
p_wei_block);
|
||||
p_wei_block,
|
||||
wei_block_desc);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -171,11 +175,13 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
blockwise_4d_tensor_copy<TFloat,
|
||||
decltype(out_block_desc),
|
||||
decltype(out_block_src_desc),
|
||||
decltype(out_block_desc),
|
||||
BlockSize>(
|
||||
out_block_desc,
|
||||
p_out_block,
|
||||
out_block_src_desc,
|
||||
p_out_global +
|
||||
out_global_desc.Get1dIndex(
|
||||
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
|
||||
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin),
|
||||
out_block_desc);
|
||||
}
|
||||
Reference in New Issue
Block a user