mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
@@ -79,11 +79,43 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
|
||||
// Copy in and wei into register before doing convolution
|
||||
template <class TFloat, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_2(InDesc,
|
||||
TFloat* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
TFloat* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(in_desc.GetLengths());
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());
|
||||
|
||||
// register
|
||||
TFloat p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
TFloat p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc);
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc);
|
||||
|
||||
// do convolution
|
||||
threadwise_direct_convolution_1(
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
|
||||
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
|
||||
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
|
||||
// load 1x1 weight into register, and do 1x1 convolution in register.
|
||||
template <class TFloat, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_2(InDesc,
|
||||
__device__ void threadwise_direct_convolution_3(InDesc,
|
||||
TFloat* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
TFloat* const __restrict__ p_wei,
|
||||
@@ -95,100 +127,100 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_desc_lds = InDesc{};
|
||||
constexpr auto wei_desc_lds = WeiDesc{};
|
||||
constexpr auto out_desc_reg = OutDesc{};
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto in_desc_reg =
|
||||
make_ConstantTensorDescriptor(Sequence<in_desc_lds.GetLength(I0),
|
||||
in_desc_lds.GetLength(I1),
|
||||
out_desc_reg.GetLength(I2),
|
||||
out_desc_reg.GetLength(I3)>{});
|
||||
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
|
||||
in_desc.GetLength(I1),
|
||||
out_desc.GetLength(I2),
|
||||
out_desc.GetLength(I3)>{});
|
||||
|
||||
constexpr auto wei_desc_reg = make_ConstantTensorDescriptor(
|
||||
Sequence<wei_desc_lds.GetLength(I0), wei_desc_lds.GetLength(I1), 1, 1>{});
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
|
||||
|
||||
TFloat p_in_reg[in_desc_reg.GetElementSpace()];
|
||||
TFloat p_wei_reg[wei_desc_reg.GetElementSpace()];
|
||||
TFloat p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
TFloat p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
constexpr unsigned in_w_new_read = 1;
|
||||
|
||||
constexpr auto in_desc_reg_new_read =
|
||||
make_ConstantTensorDescriptor(Sequence<in_desc_reg.GetLength(I0),
|
||||
in_desc_reg.GetLength(I1),
|
||||
in_desc_reg.GetLength(I2),
|
||||
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
|
||||
in_reg_desc.GetLength(I1),
|
||||
in_reg_desc.GetLength(I2),
|
||||
in_w_new_read>{});
|
||||
|
||||
#if 0
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc_lds.GetLength(I2); ++s)
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
{
|
||||
#if 1
|
||||
// read first input
|
||||
threadwise_4d_tensor_copy(in_desc_lds,
|
||||
p_in + in_desc_lds.Get1dIndex(0, 0, s, 0),
|
||||
in_desc_reg,
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, 0),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_desc_reg);
|
||||
in_reg_desc);
|
||||
|
||||
// read first 1x1 weight
|
||||
threadwise_4d_tensor_copy(wei_desc_lds,
|
||||
p_wei + wei_desc_lds.Get1dIndex(0, 0, s, 0),
|
||||
wei_desc_reg,
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, 0),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_desc_reg);
|
||||
wei_reg_desc);
|
||||
|
||||
// do first 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_desc_reg, p_in_reg, wei_desc_reg, p_wei_reg, out_desc_reg, p_out);
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
|
||||
// loop over horizontal direction
|
||||
for(unsigned r = 1; r < wei_desc_lds.GetLength(I3); ++r)
|
||||
for(unsigned r = 1; r < wei_desc.GetLength(I3); ++r)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc_lds,
|
||||
p_wei + wei_desc_lds.Get1dIndex(0, 0, s, r),
|
||||
wei_desc_reg,
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_desc_reg);
|
||||
wei_reg_desc);
|
||||
|
||||
// shift old input to the left
|
||||
threadwise_4d_tensor_shift_down(in_desc_reg, p_in_reg, I3, Number<in_w_new_read>{});
|
||||
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc_lds,
|
||||
p_in + in_desc_lds.Get1dIndex(0, 0, s, in_desc_reg.GetLength(I3) + r - 1),
|
||||
in_desc_reg,
|
||||
in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, r + in_reg_desc.GetLength(I3) - 1),
|
||||
in_reg_desc,
|
||||
p_in_reg +
|
||||
in_desc_reg.Get1dIndex(0, 0, 0, in_desc_reg.GetLength(I3) - in_w_new_read),
|
||||
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
|
||||
in_desc_reg_new_read);
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_desc_reg, p_in_reg, wei_desc_reg, p_wei_reg, out_desc_reg, p_out);
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
{
|
||||
// loop over horizontal direction
|
||||
for(unsigned r = 0; r < wei_desc_lds.GetLength(I3); ++r)
|
||||
for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc_lds,
|
||||
p_wei + wei_desc_lds.Get1dIndex(0, 0, s, r),
|
||||
wei_desc_reg,
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_desc_reg);
|
||||
wei_reg_desc);
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(in_desc_lds,
|
||||
p_in + in_desc_lds.Get1dIndex(0, 0, s, r),
|
||||
in_desc_reg,
|
||||
p_in_reg,
|
||||
in_desc_reg);
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc, p_in + in_desc.Get1dIndex(0, 0, s, r), in_reg_desc, p_in_reg, in_reg_desc);
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_desc_reg, p_in_reg, wei_desc_reg, p_wei_reg, out_desc_reg, p_out);
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
Reference in New Issue
Block a user