diff --git a/include/ck/tensor_operation/gpu/device/helper.hpp b/include/ck/tensor_operation/gpu/device/helper.hpp index c52566509f..c0e5ce9dd3 100644 --- a/include/ck/tensor_operation/gpu/device/helper.hpp +++ b/include/ck/tensor_operation/gpu/device/helper.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck/utility/common_header.hpp" @@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim, ck::Array out_lengths, ck::Array out_strides) { + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) { ck::tensor_operation::TransformConvFwdToGemm< 2, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) @@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 2, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 2, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) { ck::tensor_operation::TransformConvFwdToGemm< 2, ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } throw std::runtime_error("Incorrect conv spec"); } @@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim, ck::Array out_lengths, ck::Array out_strides) { + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) { ck::tensor_operation::TransformConvFwdToGemm< 3, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) @@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 3, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 3, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) { ck::tensor_operation::TransformConvFwdToGemm< 3, ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } throw std::runtime_error("Incorrect conv spec"); } @@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim, ck::Array out_lengths, ck::Array out_strides) { + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) { ck::tensor_operation::TransformConvFwdToGemm< 1, ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) @@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 1, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim, ck::tensor_operation::TransformConvFwdToGemm< 1, ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) { ck::tensor_operation::TransformConvFwdToGemm< 1, ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> - conv_fwd; + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; auto res = ck::tensor_operation::TransformConv(); - return res.transform_func(out_lengths, out_strides, conv_fwd); + return res.transform_func(conv_fwd); } throw std::runtime_error("Incorrect dims or conv spec"); } diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 7ef4e7f184..e36edc3cfe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -359,36 +359,17 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template __host__ __device__ static auto - MakeAGridDescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, - const ck::Array& a_g_n_c_wis_strides, - const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& b_g_k_c_xs_strides, - const ck::Array& e_g_n_k_wos_lengths, - const ck::Array& e_g_n_k_wos_strides, - const ck::Array& conv_filter_strides, - const ck::Array& conv_filter_dilations, - const ck::Array& input_left_pads, - const ck::Array& input_right_pads) + MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -398,12 +379,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template __host__ __device__ static auto - MakeBGridDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& b_g_k_c_xs_strides) + MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -413,12 +392,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template __host__ __device__ static auto - MakeEGridDescriptor_M_N(const ck::Array& e_g_n_k_wos_lengths, - const ck::Array& e_g_n_k_wos_strides) + MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -428,26 +405,27 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Shape of Ds and E must be aligned. Strides can be different. // Pass e_g_n_k_wos_lengths for logical broadcast. - __host__ __device__ static auto MakeDsGridDescriptor_M_N( - const ck::Array& e_g_n_k_wos_lengths, - const ck::Array, NumDTensor>& ds_g_n_k_wos_strides) + static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it @@ -533,21 +511,23 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ @@ -637,9 +617,20 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // D batch stride compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); }); compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; @@ -694,6 +685,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; + + GemmToConvFwdTransformer conv_to_gemm_transformer_; + AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index a7a366ffbc..e7ac9c0314 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -8,7 +8,6 @@ #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" #include "ck/host_utility/kernel_launch.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" @@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl static constexpr auto spatial_offset = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = + TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{ MPerBlock, 0 /* NPerBlock*/, KPerBlock}; @@ -234,21 +233,21 @@ struct DeviceColumnToImageImpl : independent_filter_stride; } + GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + // conv_filter_strides, + independent_filter_strides, + conv_filter_dilations, + input_left_pads_with_offset, + input_right_pads}; + // Calculate image form descriptor for the modified convolution problem const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K( - a_g_n_c_wis_lengths, - image_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - {}, // not needed for A Descriptor - c_g_n_k_wos_lengths, - {}, // not needed for A Descriptor - // conv_filter_strides, - independent_filter_strides, - conv_filter_dilations, - input_left_pads_with_offset, - input_right_pads, - N); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index c3fe54b075..6ee086d09e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_g_n_c_wis_lengths[I1]); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK template static auto - MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_AK0_M_AK1 = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseGemm using GridwiseGemm = @@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, a_grid_desc_ak0_m_ak1_{ - DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( - b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_k0_m0_m1_k1_{}, b_grid_desc_k0_n0_n1_k1_{}, ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, @@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; + GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D pointer p_ds_grid_(i) = static_cast(p_ds[i]); @@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); }); // populate desc for Ds/E @@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // tensor descriptors for problem definiton index_t num_group_; + + GemmToConvFwdTransformer conv_to_gemm_transformer_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; DsGridDesc_M_N ds_grid_desc_m_n_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index c6b84b613c..0a58cd0c88 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - c_g_n_k_wos_lengths, - c_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_g_n_c_wis_lengths[I1]); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd static auto - MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd - static auto - MakeCGridDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t({}, {}))>; - using CGridDesc_M_N = remove_cvref_t({}, {}))>; + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using CGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseGemm using GridwiseGemm = @@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd(p_b)}, p_c_grid_{static_cast(p_c)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, a_grid_desc_ak0_m_ak1_{ - DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - c_g_n_k_wos_lengths, - c_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( - b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides)}, + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + c_grid_desc_m_n_{ + DeviceOp::MakeCGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_k0_m0_m1_k1_{}, b_grid_desc_k0_n0_n1_k1_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{}, @@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const index_t Conv_N) + static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - Conv_N); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -356,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -371,14 +351,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const index_t Conv_N) + static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -388,27 +364,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Shape of Ds and E must be aligned. Strides can be different. // Pass e_g_n_k_wos_lengths for logical broadcast. - static auto MakeDsGridDescriptor_M_N( - const std::array& e_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const index_t Conv_N) + static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it @@ -496,28 +472,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - conv_N_per_block_{ - conv_to_gemm_transformer.template GetSplitedNSize( - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - conv_N_per_block_)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ @@ -623,9 +595,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle compute_ptr_offset_of_n_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; + GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); }); compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; @@ -690,6 +673,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; + + GemmToConvFwdTransformer conv_to_gemm_transformer_; + index_t conv_N_per_block_; AGridDesc_M_K a_grid_desc_m_k_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index a4d4a01a01..025123a880 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -293,39 +293,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const index_t Conv_N) + MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - Conv_N); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -344,12 +327,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template static auto - MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -367,15 +348,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const index_t Conv_N) + static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -384,7 +361,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } // desc for problem definition - using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; #define GridwiseGemmV3TemplateParams \ tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ @@ -417,9 +396,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t({}, {}))>; + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -450,27 +429,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_b_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - conv_N_per_block_{ - conv_to_gemm_transformer.template GetSplitedNSize( - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - conv_N_per_block_)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_ak0_m_ak1_{ + MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, b_grid_desc_bk0_n_bk1_{ - MakeBGridDescriptor_BK0_N_BK1(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, + MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, @@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // tensor descriptors for problem definiton index_t num_group_; + + GemmToConvFwdTransformer conv_to_gemm_transformer_; + index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index 2170a5829a..4a8cb2d592 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_g_n_c_wis_lengths[I1]); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -447,11 +420,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); } - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; - using RGridDesc_M = remove_cvref_t({}, {}))>; + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using RGridDesc_M = remove_cvref_t({}, {}))>; // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -551,21 +527,23 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, p_rs_grid_{}, // FIXME - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, r_grid_desc_m_{ DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, a_grid_desc_ak0_m_ak1_{ @@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // D batch stride compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle EDataType* p_e_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_; + GemmToConvFwdTransformer conv_to_gemm_transformer_; + // tensor descriptors for problem definiton AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 9bab947fdb..981f9f421b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto MakeAGridDescriptor(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_g_n_c_wis_lengths[I1]); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } template - static auto MakeBGridDescriptor(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition + constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc = - decltype(DeviceOp::MakeAGridDescriptor({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); - using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + decltype(DeviceOp::MakeAGridDescriptor(dummy_conv_to_gemm_transformer)); + using BGridDesc = + decltype(DeviceOp::MakeBGridDescriptor(dummy_conv_to_gemm_transformer)); + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseOp using GridwiseOp = GridwiseGemmMultipleD_Wmma< @@ -373,21 +349,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_{DeviceOp::MakeAGridDescriptor(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_{ - DeviceOp::MakeBGridDescriptor(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_{DeviceOp::MakeAGridDescriptor(conv_to_gemm_transformer_)}, + b_grid_desc_{DeviceOp::MakeBGridDescriptor(conv_to_gemm_transformer_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, @@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle }); // D desc - ds_grid_desc_m_n_ = - DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); + ds_grid_desc_m_n_ = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }, + Number{}); // populate desc for Ds/E e_grid_desc_mblock_mperblock_nblock_nperblock_ = @@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // tensor descriptors for problem definiton index_t num_group_; + + GemmToConvFwdTransformer conv_to_gemm_transformer_; + DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 9ebcb2b8c0..4828beb3a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using GemmToConvFwdTransformer = + TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{ @@ -97,19 +97,19 @@ struct DeviceImageToColumnImpl b_g_k_c_xs_lengths[I2] = C; c_g_n_k_wos_lengths[I1] = N; + GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K( - a_g_n_c_wis_lengths, - image_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - {}, // not needed for A Descriptor - c_g_n_k_wos_lengths, - {}, // not needed for A Descriptor - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - N); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 8dd6573015..07cb7a7310 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -14,22 +14,15 @@ namespace ck { namespace tensor_operation { -// function to be used on device, emulates std::accumulate -template -__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init) -{ - for(ForwardIterator x = first; x != first + count; x++) - { - init *= *x; - } - return init; -} - template struct TransformConvFwdToGemm { + private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -37,10 +30,10 @@ struct TransformConvFwdToGemm static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; - static long_index_t - calculate_element_space_size_impl(const std::array& lengths, - const std::array& strides, - index_t i) + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) { long_index_t acc = 1; for(; i < (NDimSpatial + 3); i++) @@ -52,11 +45,11 @@ struct TransformConvFwdToGemm return acc; } - template - static index_t GetSplitedNSize(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + template + static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) { const long_index_t a_element_space_size = calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); @@ -102,6 +95,216 @@ struct TransformConvFwdToGemm } } + public: + __host__ __device__ constexpr TransformConvFwdToGemm() {} + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{I1}, + WiStride_{a_g_n_c_wis_strides[I3]}, + WoStride_{c_g_n_k_wos_strides[I3]}, + XStride_{b_g_k_c_xs_strides[I3]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); + + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + NDoHoWo_ = N_ * Wo_; + } + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{a_g_n_c_wis_strides[I3]}, + WiStride_{a_g_n_c_wis_strides[I4]}, + WoStride_{c_g_n_k_wos_strides[I4]}, + XStride_{b_g_k_c_xs_strides[I4]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); + + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + NDoHoWo_ = N_ * Ho_ * Wo_; + } + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{a_g_n_c_wis_strides[I3]}, + HiStride_{a_g_n_c_wis_strides[I4]}, + WiStride_{a_g_n_c_wis_strides[I5]}, + WoStride_{c_g_n_k_wos_strides[I5]}, + XStride_{b_g_k_c_xs_strides[I5]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); + + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + NDoHoWo_ = N_ * Do_ * Ho_ * Wo_; + } + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const index_t N) + __host__ __device__ auto MakeADescriptor_M_K() const { - const index_t C = a_g_n_c_wis_lengths[I2]; - - const index_t Wi = a_g_n_c_wis_lengths[I3]; - - const index_t Wo = c_g_n_k_wos_lengths[I3]; - - const index_t ConvStrideW = conv_filter_strides[I0]; - - const index_t GStride = a_g_n_c_wis_strides[I0]; - const index_t NStride = a_g_n_c_wis_strides[I1]; - const auto CStride = a_g_n_c_wis_strides[I2]; - const index_t WiStride = a_g_n_c_wis_strides[I3]; - if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NHoWo, C), - make_tuple(WiStride, CStride)); + return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), + make_tuple(WiStride_, CStrideTensorA_)); } else { const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + make_tuple(NDoHoWo_, NumGroupsToMerge, C_), + make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -164,35 +340,30 @@ struct TransformConvFwdToGemm else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { - const index_t ConvDilationW = conv_filter_dilations[0]; - - const index_t InLeftPadW = input_left_pads[0]; - - const index_t InRightPadW = input_right_pads[0]; if constexpr(NumGroupsToMerge == 1) { - const auto in_n_wi_c_desc = - make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride)); + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_)); const auto in_n_wip_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{})); return transform_tensor_descriptor( in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(Number<3>{})), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -200,28 +371,29 @@ struct TransformConvFwdToGemm else { const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride)); + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_)); const auto in_n_wip_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); return transform_tensor_descriptor( in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), make_pass_through_transform(Number<3>{})), make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -233,110 +405,108 @@ struct TransformConvFwdToGemm if constexpr(NumGroupsToMerge == 1) { const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); const auto in_n_wo_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); return transform_tensor_descriptor( in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { - const auto in_n_wi_c_desc = - make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), - make_tuple(NStride, WiStride, GStride, CStride)); + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); const auto in_n_wo_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); return transform_tensor_descriptor( in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } else { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; if constexpr(NumGroupsToMerge == 1) { const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); const auto in_n_wip_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); return transform_tensor_descriptor( in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_merge_transform(make_tuple(X_, C_))), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { - const auto in_n_wi_c_desc = - make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), - make_tuple(NStride, WiStride, GStride, CStride)); + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); const auto in_n_wip_c_desc = transform_tensor_descriptor( in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); return transform_tensor_descriptor( in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), - make_merge_transform(make_tuple(X, C))), + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(X_, C_))), make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -349,57 +519,27 @@ struct TransformConvFwdToGemm is_same_v || is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const index_t N) + __host__ __device__ auto MakeADescriptor_M_K() const { - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t GStride = a_g_n_c_wis_strides[I0]; - const index_t NStride = a_g_n_c_wis_strides[I1]; - const index_t CStride = a_g_n_c_wis_strides[I2]; - const index_t HiStride = a_g_n_c_wis_strides[I3]; - const index_t WiStride = a_g_n_c_wis_strides[I4]; - if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NHoWo, C), - make_tuple(WiStride, CStride)); + return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), + make_tuple(WiStride_, CStrideTensorA_)); } else { const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + make_tuple(NDoHoWo_, NumGroupsToMerge, C_), + make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -407,73 +547,65 @@ struct TransformConvFwdToGemm else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; if constexpr(NumGroupsToMerge == 1) { const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride)); + make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); return transform_tensor_descriptor( in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { - const auto in_n_hi_wi_groups_c_desc = - make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge), - make_tuple(NStride, HiStride, WiStride, GStride)); + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_)); const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( in_n_hi_wi_groups_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( in_n_hip_wip_groups_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); return transform_tensor_descriptor( in_n_y_ho_x_wo_groups_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -485,37 +617,39 @@ struct TransformConvFwdToGemm if constexpr(NumGroupsToMerge == 1) { const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); return transform_tensor_descriptor( in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, NumGroupsToMerge, C), - make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( in_n_hi_wi_groups_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -523,55 +657,44 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_ho_wo_groups_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } else { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - if constexpr(NumGroupsToMerge == 1) { const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); return transform_tensor_descriptor( in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Y_, X_, C_))), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -579,16 +702,17 @@ struct TransformConvFwdToGemm { const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, NumGroupsToMerge, C), - make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( in_n_hi_wi_groups_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -596,13 +720,13 @@ struct TransformConvFwdToGemm const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( in_n_hip_wip_groups_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, @@ -613,8 +737,8 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_y_ho_x_wo_groups_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), - make_merge_transform(make_tuple(Y, X, C))), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y_, X_, C_))), make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -627,63 +751,27 @@ struct TransformConvFwdToGemm is_same_v || is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides*/, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const index_t N) + __host__ __device__ auto MakeADescriptor_M_K() const { - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t GStride = a_g_n_c_wis_strides[I0]; - const index_t NStride = a_g_n_c_wis_strides[I1]; - const index_t CStride = a_g_n_c_wis_strides[I2]; - const index_t DiStride = a_g_n_c_wis_strides[I3]; - const index_t HiStride = a_g_n_c_wis_strides[I4]; - const index_t WiStride = a_g_n_c_wis_strides[I5]; - if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), - make_tuple(WiStride, CStride)); + return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), + make_tuple(WiStride_, CStrideTensorA_)); } else { - const auto in_gemmm_groups_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C), - make_tuple(WiStride, GStride, CStride)); + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NDoHoWo_, NumGroupsToMerge, C_), + make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -691,41 +779,30 @@ struct TransformConvFwdToGemm else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - if constexpr(NumGroupsToMerge == 1) { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride)); + make_tuple(N_, Di_, Hi_, Wi_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Do), - make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Number<3>{}, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); @@ -733,7 +810,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, make_tuple( - make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -741,15 +818,15 @@ struct TransformConvFwdToGemm else { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, NumGroupsToMerge), - make_tuple(NStride, DiStride, HiStride, WiStride, GStride)); + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -758,13 +835,13 @@ struct TransformConvFwdToGemm const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Number<3>{}, Do), - make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Number<3>{}, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(Number<3>{}, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -777,7 +854,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, make_tuple( - make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -789,16 +866,16 @@ struct TransformConvFwdToGemm if constexpr(NumGroupsToMerge == 1) { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -806,25 +883,30 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), - make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -840,43 +922,28 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), - make_pass_through_transform(C)), + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } else { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - if constexpr(NumGroupsToMerge == 1) { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple( @@ -884,14 +951,14 @@ struct TransformConvFwdToGemm const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), - make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, @@ -902,25 +969,30 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), - make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -936,15 +1008,15 @@ struct TransformConvFwdToGemm const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), - make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), - make_tuple(ConvDilationW, ConvStrideW)), + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C)), + make_pass_through_transform(C_)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -960,8 +1032,9 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), - make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -973,19 +1046,8 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + __host__ __device__ auto MakeBDescriptor_N_K() const { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const index_t GStride = b_g_k_c_xs_strides[I0]; - const index_t KStride = b_g_k_c_xs_strides[I1]; - const index_t CStride = b_g_k_c_xs_strides[I2]; - if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { @@ -996,17 +1058,17 @@ struct TransformConvFwdToGemm if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{})); + return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{})); } else { const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}), - make_tuple(KStride, GStride, CStride)); + make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); return transform_tensor_descriptor( wei_gemmn_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), make_pass_through_transform(FilterSizeNumType{})), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -1016,16 +1078,17 @@ struct TransformConvFwdToGemm { if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_)); } else { const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride)); + make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); return transform_tensor_descriptor( wei_gemmn_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), - make_pass_through_transform(YX * C)), + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(ZYX_ * C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1041,25 +1104,14 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + __host__ __device__ auto MakeBDescriptor_N_K() const { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const index_t KStride = b_g_k_c_xs_strides[1]; - const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; - const auto CStride = I1; - const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( - make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_)); const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( wei_k_yx_c_desc, - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -1071,24 +1123,14 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const index_t N) + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t K = c_g_n_k_wos_lengths[2]; - - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); - - return out_gemmm_gemmn_desc; + return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_)); } template < typename CLayout, + typename std::enable_if || is_same_v || is_same_v || @@ -1096,39 +1138,28 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides, - const index_t N) + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t K = c_g_n_k_wos_lengths[2]; - - const index_t KStride = I1; - const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; - const index_t GStride = c_g_n_k_wos_strides[0]; - - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NHoWo, K), - make_tuple(WoStride, KStride)); + return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), + make_tuple(WoStride_, KStrideTensorC_)); } else { - const auto nhwo_groups_k_1_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1), - make_tuple(WoStride, GStride, KStride, GStride)); + const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( + make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1), + make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( nhwo_groups_k_1_desc, - make_tuple(make_pass_through_transform(NHoWo), + make_tuple(make_pass_through_transform(NDoHoWo_), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(K), + make_pass_through_transform(K_), make_pad_transform(1, 0, NumGroupsToMerge - 1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - // We need only matrices from diagonal. Xor returns 0 for the same + // We need only matrices from diagonal. X_or returns 0 for the same // values. So if matrices is not on diagonal then it will be stored in padding. // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || @@ -1136,16 +1167,16 @@ struct TransformConvFwdToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_pass_through_transform(NHoWo), + make_tuple(make_pass_through_transform(NDoHoWo_), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K)), + make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), - make_merge_transform(make_tuple(K, NumGroupsToMerge))), + make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1155,542 +1186,34 @@ struct TransformConvFwdToGemm template , bool>::type = false> - static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides, - const index_t N) + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t K = c_g_n_k_wos_lengths[2]; - const index_t KStride = c_g_n_k_wos_strides[2]; - - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); + make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_)); return out_gemmm_gemmn_desc; } - // Overloaded functions for hipRTC purposes - template || - is_same_v || - is_same_v), - bool>::type = false> - __host__ __device__ static auto - MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, - const ck::Array& a_g_n_c_wis_strides, - const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& /* b_g_k_c_xs_strides */, - const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& /* c_g_n_k_wos_strides */, - const ck::Array& conv_filter_strides, - const ck::Array& conv_filter_dilations, - const ck::Array& input_left_pads, - const ck::Array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Wi = a_g_n_c_wis_lengths[3]; - - const index_t Wo = c_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template || - is_same_v || - is_same_v), - bool>::type = false> - __host__ __device__ static auto - MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, - const ck::Array& a_g_n_c_wis_strides, - const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& /* b_g_k_c_xs_strides */, - const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& /* c_g_n_k_wos_strides */, - const ck::Array& conv_filter_strides, - const ck::Array& conv_filter_dilations, - const ck::Array& input_left_pads, - const ck::Array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template || - is_same_v || - is_same_v), - bool>::type = false> - static auto - MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, - const ck::Array& a_g_n_c_wis_strides, - const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& /* b_g_k_c_xs_strides */, - const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& /* c_g_n_k_wos_strides */, - const ck::Array& conv_filter_strides, - const ck::Array& conv_filter_dilations, - const ck::Array& input_left_pads, - const ck::Array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template || - is_same_v || - is_same_v, - bool>::type = false> - __host__ __device__ static auto - MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& /* b_g_k_c_xs_strides */) - { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = - mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); - - const auto wei_gemmn_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); - - return wei_gemmn_gemmk_desc; - } - - template < - typename BLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> - __host__ __device__ static auto - MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, - const ck::Array& b_g_k_c_xs_strides) - { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = - mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); - - const index_t KStride = b_g_k_c_xs_strides[1]; - const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( - make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); - - const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( - wei_k_yx_c_desc, - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return wei_gemmn_gemmk_desc; - } - - template || - is_same_v || - is_same_v, - bool>::type = false> - __host__ __device__ static auto - MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& /* c_g_n_k_wos_strides */) - { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; - - const index_t NHoWo = - N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); - - const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); - - return out_gemmm_gemmn_desc; - } - - template < - typename CLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> - __host__ __device__ static auto - MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& c_g_n_k_wos_strides) - { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; - - const auto KStride = I1; - const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; - - const index_t NHoWo = - N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); - - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); - - return out_gemmm_gemmn_desc; - } - - // for output bias - template , - bool>::type = false> - __host__ __device__ static auto - MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, - const ck::Array& c_g_n_k_wos_strides) - { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; - const index_t KStride = c_g_n_k_wos_strides[2]; - - const index_t NHoWo = - N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); - - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); - - return out_gemmm_gemmn_desc; - } + public: + index_t N_; + + private: + const index_t Di_, Hi_, Wi_; + const index_t Do_, Ho_, Wo_; + const index_t Z_, Y_, X_; + const index_t K_, C_; + const index_t DiStride_, HiStride_, WiStride_; + const index_t WoStride_; + const index_t XStride_; + const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_; + const index_t NStrideTensorA_; + const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; + const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_; + const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_; + const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_; + const index_t InRightPadD_, InRightPadH_, InRightPadW_; + const index_t ZYX_; + index_t NDoHoWo_; }; // wrapper class to call member functions on TransformConvToGemm struct at runtime @@ -1702,26 +1225,22 @@ struct TransformConv template auto - transform_func(ck::Array out_lengths, - ck::Array out_strides, - TransformConvFwdToGemm conv_fwd_to_gemm) + transform_func(TransformConvFwdToGemm conv_fwd_to_gemm) { if(NDimSpatial == 2) { return conv_fwd_to_gemm - .template MakeCDescriptor_M_N(out_lengths, - out_strides); + .template MakeCDescriptor_M_N(); } else if(NDimSpatial == 3) { return conv_fwd_to_gemm - .template MakeCDescriptor_M_N(out_lengths, - out_strides); + .template MakeCDescriptor_M_N(); } else if(NDimSpatial == 1) { - return conv_fwd_to_gemm.template MakeCDescriptor_M_N( - out_lengths, out_strides); + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); } } };