mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
@@ -101,10 +101,10 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc);
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc);
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths());
|
||||
|
||||
// do convolution
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -159,14 +159,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, 0),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc);
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// read first 1x1 weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, 0),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// do first 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -180,7 +180,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// shift old input to the left
|
||||
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
|
||||
@@ -192,7 +192,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
in_reg_desc,
|
||||
p_in_reg +
|
||||
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
|
||||
in_desc_reg_new_read);
|
||||
in_desc_reg_new_read.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
@@ -211,11 +211,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc);
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc, p_in + in_desc.Get1dIndex(0, 0, s, r), in_reg_desc, p_in_reg, in_reg_desc);
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, r),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
|
||||
Reference in New Issue
Block a user