mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
another version of direct conv
This commit is contained in:
@@ -46,15 +46,11 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
|
||||
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
|
||||
|
||||
constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
|
||||
constexpr auto in_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
|
||||
|
||||
constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerBlock, CPerBlock, S, R>{}, wei_global_desc.GetStrides());
|
||||
|
||||
constexpr auto in_block_desc = make_ConstantTensorDescriptor(in_block_global_desc.GetLengths());
|
||||
constexpr auto wei_block_desc =
|
||||
make_ConstantTensorDescriptor(wei_block_global_desc.GetLengths());
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, S, R>{});
|
||||
|
||||
// shared mem
|
||||
constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
|
||||
@@ -67,30 +63,19 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
|
||||
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
|
||||
|
||||
constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{}, in_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, S, R>{}, wei_block_desc.GetStrides());
|
||||
|
||||
constexpr auto in_thread_desc =
|
||||
make_ConstantTensorDescriptor(in_thread_block_desc.GetLengths());
|
||||
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{});
|
||||
|
||||
constexpr auto wei_thread_desc =
|
||||
make_ConstantTensorDescriptor(wei_thread_block_desc.GetLengths());
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, S, R>{});
|
||||
|
||||
constexpr auto out_thread_desc =
|
||||
get_output_4d_tensor_descriptor(in_thread_desc, wei_thread_desc);
|
||||
|
||||
constexpr auto out_thread_global_desc =
|
||||
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
|
||||
|
||||
// register
|
||||
constexpr unsigned in_thread_size = in_thread_desc.GetElementSpace();
|
||||
constexpr unsigned wei_thread_size = wei_thread_desc.GetElementSpace();
|
||||
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
|
||||
|
||||
TFloat p_in_thread[in_thread_size];
|
||||
TFloat p_wei_thread[wei_thread_size];
|
||||
TFloat p_out_thread[out_thread_size];
|
||||
TFloat p_in_thread[in_thread_desc.GetElementSpace()];
|
||||
TFloat p_wei_thread[wei_thread_desc.GetElementSpace()];
|
||||
TFloat p_out_thread[out_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
@@ -169,54 +154,60 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_4d_tensor_copy<TFloat,
|
||||
decltype(in_block_global_desc),
|
||||
decltype(in_global_desc),
|
||||
decltype(in_block_desc),
|
||||
BlockSize>(in_block_global_desc,
|
||||
decltype(in_block_desc),
|
||||
BlockSize>(in_global_desc,
|
||||
p_in_global +
|
||||
in_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
in_block_desc,
|
||||
p_in_block);
|
||||
p_in_block,
|
||||
in_block_desc);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_4d_tensor_copy<TFloat,
|
||||
decltype(wei_block_global_desc),
|
||||
decltype(wei_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc),
|
||||
BlockSize>(
|
||||
wei_block_global_desc,
|
||||
wei_global_desc,
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
wei_block_desc,
|
||||
p_wei_block);
|
||||
p_wei_block,
|
||||
wei_block_desc);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_thread_block_desc,
|
||||
threadwise_4d_tensor_copy(in_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
in_thread_desc,
|
||||
p_in_thread);
|
||||
p_in_thread,
|
||||
in_thread_desc);
|
||||
|
||||
// copy weight tensor into register
|
||||
threadwise_4d_tensor_copy(
|
||||
wei_thread_block_desc,
|
||||
wei_block_desc,
|
||||
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
wei_thread_desc,
|
||||
p_wei_thread);
|
||||
p_wei_thread,
|
||||
wei_thread_desc);
|
||||
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution(in_thread_desc,
|
||||
p_in_thread,
|
||||
wei_thread_desc,
|
||||
p_wei_thread,
|
||||
out_thread_desc,
|
||||
p_out_thread);
|
||||
threadwise_direct_convolution_1(in_thread_desc,
|
||||
p_in_thread,
|
||||
wei_thread_desc,
|
||||
p_wei_thread,
|
||||
out_thread_desc,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,9 +215,10 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
threadwise_4d_tensor_copy(
|
||||
out_thread_desc,
|
||||
p_out_thread,
|
||||
out_thread_global_desc,
|
||||
out_global_desc,
|
||||
p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin));
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_thread_desc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user