refactor direct

[ROCm/composable_kernel commit: 24d2f034fa]
This commit is contained in:
Chao Liu
2018-11-25 01:10:11 -06:00
parent 40ddf8c139
commit 7569eeaf55
14 changed files with 253 additions and 1291 deletions

View File

@@ -17,10 +17,10 @@ __device__ void blockwise_convolution(InBlockDesc,
OutBlockDesc,
TFloat* __restrict__ p_out_block)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_block_desc = InBlockDesc{};
constexpr auto wei_block_desc = WeiBlockDesc{};
@@ -88,72 +88,50 @@ __device__ void blockwise_convolution(InBlockDesc,
TFloat p_wei_thread[wei_thread_src_desc.GetElementSpace()];
TFloat p_out_thread[out_thread_src_desc.GetElementSpace()];
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
// copy input tensor into register
threadwise_4d_tensor_op_binary<TFloat,
decltype(in_thread_src_desc),
decltype(in_thread_dst_desc),
decltype(f_copy)>(
threadwise_4d_tensor_copy(
in_thread_src_desc,
p_in_block + in_block_desc.Get1dIndex(
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
in_thread_dst_desc,
p_in_thread,
f_copy);
p_in_thread);
for(unsigned k_thread_work_begin = 0; k_thread_work_begin < KPerBlock;
++k_thread_work_begin)
{
// copy weight tensor into register
threadwise_4d_tensor_op_binary<TFloat,
decltype(wei_thread_src_desc),
decltype(wei_thread_dst_desc),
decltype(f_copy)>(
wei_thread_src_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
wei_thread_dst_desc,
p_wei_thread,
f_copy);
threadwise_4d_tensor_copy(wei_thread_src_desc,
p_wei_block +
wei_block_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
wei_thread_dst_desc,
p_wei_thread);
// copy output tensor into register
threadwise_4d_tensor_op_binary<TFloat,
decltype(out_thread_src_desc),
decltype(out_thread_dst_desc),
decltype(f_copy)>(
out_thread_src_desc,
p_out_block + out_block_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
out_thread_dst_desc,
p_out_thread,
f_copy);
threadwise_4d_tensor_copy(out_thread_src_desc,
p_out_block + out_block_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
out_thread_dst_desc,
p_out_thread);
// threadwise convolution
threadwise_direct_convolution<TFloat,
decltype(in_thread_dst_desc),
decltype(wei_thread_dst_desc),
decltype(out_thread_dst_desc)>(in_thread_dst_desc,
p_in_thread,
wei_thread_dst_desc,
p_wei_thread,
out_thread_dst_desc,
p_out_thread);
threadwise_direct_convolution(in_thread_dst_desc,
p_in_thread,
wei_thread_dst_desc,
p_wei_thread,
out_thread_dst_desc,
p_out_thread);
// accumulate output tensor into LDS
threadwise_4d_tensor_op_binary<TFloat,
decltype(out_thread_dst_desc),
decltype(out_thread_src_desc),
decltype(f_copy)>(
out_thread_dst_desc,
p_out_thread,
out_thread_src_desc,
p_out_block + out_block_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
f_copy);
threadwise_4d_tensor_copy(out_thread_dst_desc,
p_out_thread,
out_thread_src_desc,
p_out_block +
out_block_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin));
}
}
}