mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
refactor
This commit is contained in:
@@ -41,8 +41,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
|
||||
constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
|
||||
constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;
|
||||
|
||||
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
|
||||
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
|
||||
|
||||
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
|
||||
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
|
||||
@@ -102,11 +102,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
|
||||
constexpr auto wei_transform_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, InTileSizeH, InTileSizeW>{});
|
||||
|
||||
constexpr unsigned in_transform_block_size = in_transform_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_transform_block_size = wei_transform_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ TFloat p_in_transform_block[in_transform_block_size];
|
||||
__shared__ TFloat p_wei_transform_block[wei_transform_block_size];
|
||||
__shared__ TFloat p_in_transform_block[in_transform_block_desc.GetElementSpace()];
|
||||
__shared__ TFloat p_wei_transform_block[wei_transform_block_desc.GetElementSpace()];
|
||||
|
||||
// thread data
|
||||
constexpr auto in_transform_thread_block_desc =
|
||||
@@ -126,11 +123,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
|
||||
constexpr auto out_thread_global_desc =
|
||||
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
|
||||
|
||||
constexpr unsigned out_transform_thread_size = out_transform_thread_desc.GetElementSpace();
|
||||
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
|
||||
|
||||
TFloat p_out_transform_thread[out_transform_thread_size];
|
||||
TFloat p_out_thread[out_thread_size];
|
||||
TFloat p_out_transform_thread[out_transform_thread_desc.GetElementSpace()];
|
||||
TFloat p_out_thread[out_thread_desc.GetElementSpace()];
|
||||
|
||||
#if 0
|
||||
if(blockIdx.x == 0 && threadIdx.x == 0)
|
||||
|
||||
@@ -116,10 +116,13 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, ID
|
||||
|
||||
const unsigned did0_end =
|
||||
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - shift : desc.GetLength(I0);
|
||||
|
||||
const unsigned did1_end =
|
||||
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - shift : desc.GetLength(I1);
|
||||
|
||||
const unsigned did2_end =
|
||||
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - shift : desc.GetLength(I2);
|
||||
|
||||
const unsigned did3_end =
|
||||
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - shift : desc.GetLength(I3);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user