diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp index 38e9360016..433ba2d855 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -22,6 +22,8 @@ template {}) - .Slice(I3, Number{}) - .Fold(I0, Number{}, Number{}) - .Extract(Sequence<0, 1, 2, 4, 5>{}); + constexpr auto in_n0_n1_n2_h_w_global_desc = + in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) + .StridedSlice(I3, Number{}, Number{}) + .Fold(I0, Number{}, Number{}) + .Extract(Sequence<0, 1, 2, 4, 5>{}); // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number{}) - .Slice(I3, Number{}) - .Extract(Sequence<1, 2, 3>{}); + constexpr auto in_c_y_x_global_desc = + in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) + .StridedSlice(I3, Number{}, Number{}) + .Extract(Sequence<1, 2, 3>{}); // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp index 76c3761d10..f93a5a60cd 100644 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp @@ -320,6 +320,18 @@ struct ConstantTensorDescriptor return ConstantTensorDescriptor{}; } + template + __host__ __device__ static constexpr auto + StridedSlice(Number, Number, Number) + { + constexpr index_t new_stride = Strides::Get(Number{}) * SliceStride; + + using new_lengths = decltype(Lengths::Modify(Number{}, Number{})); + using new_strides = decltype(Strides::Modify(Number{}, Number{})); + + return ConstantTensorDescriptor{}; + } + template __host__ __device__ static constexpr auto Fold(Number, Number...) { diff --git a/driver/include/conv_common.hpp b/driver/include/conv_common.hpp index 254f4c5651..d1ddb42317 100644 --- a/driver/include/conv_common.hpp +++ b/driver/include/conv_common.hpp @@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe return make_ConstantTensorDescriptor_packed(Sequence{}); } -template -constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(InDesc, - WeiDesc, - LowerPads, - UpperPads) +template +constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor( + InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads) { constexpr auto in_desc = InDesc{}; constexpr auto wei_desc = WeiDesc{}; @@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor( static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1), "input & weight dimension not consistent"); - constexpr auto N = in_desc.GetLength(I0); - constexpr auto HI = in_desc.GetLength(I2); - constexpr auto WI = in_desc.GetLength(I3); + constexpr index_t N = in_desc.GetLength(I0); + constexpr index_t Hi = in_desc.GetLength(I2); + constexpr index_t Wi = in_desc.GetLength(I3); - constexpr auto K = wei_desc.GetLength(I0); - constexpr auto Y = wei_desc.GetLength(I2); - constexpr auto X = wei_desc.GetLength(I3); + constexpr index_t K = wei_desc.GetLength(I0); + constexpr index_t Y = wei_desc.GetLength(I2); + constexpr index_t X = wei_desc.GetLength(I3); - constexpr auto HPadLow = LowerPads{}.Get(I0); - constexpr auto WPadLow = LowerPads{}.Get(I1); + constexpr index_t HPadLow = LowerPads{}.Get(I0); + constexpr index_t WPadLow = LowerPads{}.Get(I1); - constexpr auto HPadUp = UpperPads{}.Get(I0); - constexpr auto WPadUp = UpperPads{}.Get(I1); + constexpr index_t HPadUp = UpperPads{}.Get(I0); + constexpr index_t WPadUp = UpperPads{}.Get(I1); - constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y; - constexpr auto WO = WI + WPadLow + WPadUp + 1 - X; + constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; + constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - return make_ConstantTensorDescriptor_packed(Sequence{}); + constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1; + constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1; + + return make_ConstantTensorDescriptor_packed(Sequence{}); } template diff --git a/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp index 80a6155271..3d12acd24a 100644 --- a/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp @@ -8,13 +8,20 @@ using namespace ck; -template +template void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, const Tensor& in_nchw, WeiDesc, const Tensor& wei_kcyx, OutDesc, Tensor& out_nkhw, + ConvStrides, + ConvDilations, index_t nrepeat) { constexpr auto I0 = Number<0>{}; @@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, decltype(in_nchw_desc), decltype(wei_kcyx_desc), decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, BPerBlock, KPerBlock, CPerBlock, diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index b930734c00..bfca88bc53 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc) return TensorDescriptor(lengths, strides); } -template +template void host_direct_convolution(const Tensor& in_nchw, const Tensor& wei_kcyx, Tensor& out_nkhw, + ConvStrides, + ConvDilations, LowerPads, UpperPads) { @@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor& in_nchw, { for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) { - int hi = ho + y - h_pad_low; + int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low; for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) { - int wi = wo + x - w_pad_low; + int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low; if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && wi < in_nchw.mDesc.GetLengths()[3]) { @@ -419,9 +427,9 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 3x3, 34x34 - constexpr index_t N = 64; + constexpr index_t N = 128; constexpr index_t C = 256; constexpr index_t HI = 34; constexpr index_t WI = 34; @@ -429,6 +437,9 @@ int main(int argc, char* argv[]) constexpr index_t Y = 3; constexpr index_t X = 3; + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 @@ -453,6 +464,9 @@ int main(int argc, char* argv[]) constexpr index_t Y = 3; constexpr index_t X = 3; + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 @@ -583,7 +597,7 @@ int main(int argc, char* argv[]) auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence{}); auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence{}); auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( - in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads); + in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); @@ -645,9 +659,17 @@ int main(int argc, char* argv[]) #elif 1 device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw #endif - (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); + (in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + nrepeat); -#elif 1 +#elif 0 device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc, in_nchw, wei_kcyx_desc, @@ -662,14 +684,21 @@ int main(int argc, char* argv[]) if(do_verification) { #if 1 - if(Y == 3 && X == 3) + if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 && + ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1) { host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); } else #endif { - host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); + host_direct_convolution(in_nchw, + wei_kcyx, + out_nkhw_host, + ConvStrides{}, + ConvDilations{}, + lower_pads, + upper_pads); } check_error(out_nkhw_host, out_nkhw_device);