mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
clean up
This commit is contained in:
@@ -92,47 +92,44 @@ __device__ void blockwise_convolution(InDesc,
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(in_thread_src_desc),
|
||||
decltype(in_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
threadwise_4d_tensor_op_in<TFloat,
|
||||
decltype(in_thread_src_desc),
|
||||
decltype(in_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
in_thread_src_desc,
|
||||
p_in + in_desc.Get1dIndex(
|
||||
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
|
||||
in_thread_dst_desc,
|
||||
p_in_thread,
|
||||
f_copy,
|
||||
false);
|
||||
f_copy);
|
||||
|
||||
for(unsigned k_thread_work_begin = 0; k_thread_work_begin < KPerBlock;
|
||||
++k_thread_work_begin)
|
||||
{
|
||||
// copy weight tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(wei_thread_src_desc),
|
||||
decltype(wei_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
threadwise_4d_tensor_op_wei<TFloat,
|
||||
decltype(wei_thread_src_desc),
|
||||
decltype(wei_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
wei_thread_src_desc,
|
||||
p_wei + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
|
||||
wei_thread_dst_desc,
|
||||
p_wei_thread,
|
||||
f_copy,
|
||||
false);
|
||||
f_copy);
|
||||
|
||||
// copy output tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(f_copy)>(out_thread_src_desc,
|
||||
p_out +
|
||||
out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
f_copy,
|
||||
false);
|
||||
threadwise_4d_tensor_op_out<TFloat,
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
f_copy);
|
||||
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution<TFloat,
|
||||
@@ -146,19 +143,18 @@ __device__ void blockwise_convolution(InDesc,
|
||||
p_out_thread);
|
||||
|
||||
// accumulate output tensor into LDS
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(f_copy)>(out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
out_thread_src_desc,
|
||||
p_out +
|
||||
out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
f_copy,
|
||||
false);
|
||||
threadwise_4d_tensor_op_out<TFloat,
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(f_copy)>(
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
f_copy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user