From feca6e57f9765af28d19ef294d79bb2c6e5b931c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 31 Aug 2022 11:27:11 -0500 Subject: [PATCH] conv+conv (1x1 only) example using gemm+gemm (#393) * refactor conv * add conv+conv example, 1x1 only [ROCm/composable_kernel commit: 4df6d93f6092b4ffe6878fceeec15d4c70c94d62] --- example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp | 113 +-- ...uped_convnd_fwd_bias_relu_add_xdl_fp16.cpp | 6 +- .../41_grouped_conv_conv_fwd/CMakeLists.txt | 1 + .../grouped_conv_conv_fwd_common.hpp | 257 +++++ .../grouped_conv_conv_fwd_xdl_fp16.cpp | 204 ++++ example/CMakeLists.txt | 1 + .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 9 + .../device_batched_gemm_multi_d_xdl.hpp | 1 + .../device_gemm_multiple_d_xdl_cshuffle.hpp | 4 +- .../device_grouped_conv_fwd_multiple_d.hpp | 2 +- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 922 +----------------- .../gpu/device/matrix_padder.hpp | 140 +-- .../transform_conv_fwd_to_gemm.hpp | 870 +++++++++++++++++ .../library/utility/convolution_parameter.hpp | 49 +- 14 files changed, 1524 insertions(+), 1055 deletions(-) create mode 100644 example/41_grouped_conv_conv_fwd/CMakeLists.txt create mode 100644 example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_common.hpp create mode 100644 example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index c4df64abe4..a8432c5892 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -76,8 +76,6 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { - namespace ctc = ck::tensor_layout::convolution; - print_helper_msg(); bool do_verification = true; @@ -111,11 +109,12 @@ int main(int argc, char* argv[]) const auto wei_element_op = WeiElementOp{}; const auto out_element_op = OutElementOp{}; - if(conv_param.num_dim_spatial_ == 1) - { - using InLayout = ctc::GNWC; - using WeiLayout = ctc::GKXC; - using OutLayout = ctc::GNWK; + const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( @@ -130,97 +129,39 @@ int main(int argc, char* argv[]) conv_param); return run_grouped_conv_fwd< - 1, + ndim_spatial_value, InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp, - DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + DeviceGroupedConvNDFwdInstance>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{}); } else if(conv_param.num_dim_spatial_ == 2) { - using InLayout = ctc::GNHWC; - using WeiLayout = ctc::GKYXC; - using OutLayout = ctc::GNHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_grouped_conv_fwd< - 2, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{}); } else if(conv_param.num_dim_spatial_ == 3) { - using InLayout = ctc::GNDHWC; - using WeiLayout = ctc::GKZYXC; - using OutLayout = ctc::GNDHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_grouped_conv_fwd< - 3, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); } return 0; diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp index 8846633982..2fb2681ea6 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp @@ -163,9 +163,9 @@ int main(int argc, char* argv[]) {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, { conv_param.K_, // g - 0, // k - 1, // c - 0 // x + 0, // n + 1, // k + 0 // wo }); const auto residual_g_n_k_wos_desc = HostTensorDescriptor( diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt new file mode 100644 index 0000000000..ef88eca12c --- /dev/null +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_common.hpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_common.hpp new file mode 100644 index 0000000000..5ad1ff9576 --- /dev/null +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_common.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +template +int run_grouped_conv_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv0_param, + const ck::utils::conv::ConvParam& conv1_param, + const HostTensorDescriptor& in0_g_n_c_wis_desc, + const HostTensorDescriptor& wei0_g_k_c_xs_desc, + const HostTensorDescriptor& out0_g_n_k_wos_desc, + const HostTensorDescriptor& wei1_g_k_c_xs_desc, + const HostTensorDescriptor& out1_g_n_k_wos_desc, + const In0ElementOp& in0_element_op, + const Wei0ElementOp& wei0_element_op, + const Wei1ElementOp& wei1_element_op, + const Out0ElementOp& out0_element_op, + const Out1ElementOp& out1_element_op) +{ + Tensor in0(in0_g_n_c_wis_desc); + Tensor wei0(wei0_g_k_c_xs_desc); + Tensor wei1(wei1_g_k_c_xs_desc); + Tensor out1_host(out1_g_n_k_wos_desc); + Tensor out1_device(out1_g_n_k_wos_desc); + + std::cout << "in0: " << in0.mDesc << std::endl; + std::cout << "wei0: " << wei0.mDesc << std::endl; + std::cout << "wei1: " << wei1.mDesc << std::endl; + std::cout << "out1: " << out1_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in0.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei0.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei1.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei0.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + wei1.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize()); + DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize()); + DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize()); + DeviceMem out1_device_buf(sizeof(Out1DataType) * out1_device.mDesc.GetElementSpaceSize()); + + in0_device_buf.ToDevice(in0.mData.data()); + wei0_device_buf.ToDevice(wei0.mData.data()); + wei1_device_buf.ToDevice(wei1.mData.data()); + + std::array a0_g_n_c_wis_lengths{}; + std::array a0_g_n_c_wis_strides{}; + std::array b0_g_k_c_xs_lengths{}; + std::array b0_g_k_c_xs_strides{}; + std::array b1_g_k_c_xs_lengths{}; + std::array b1_g_k_c_xs_strides{}; + std::array e1_g_n_k_wos_lengths{}; + std::array e1_g_n_k_wos_strides{}; + std::array conv0_filter_strides{}; + std::array conv0_filter_dilations{}; + std::array input0_left_pads{}; + std::array input0_right_pads{}; + std::array conv1_filter_strides{}; + std::array conv1_filter_dilations{}; + std::array input1_left_pads{}; + std::array input1_right_pads{}; + + auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(in0_g_n_c_wis_desc.GetLengths(), a0_g_n_c_wis_lengths); + copy(in0_g_n_c_wis_desc.GetStrides(), a0_g_n_c_wis_strides); + copy(wei0_g_k_c_xs_desc.GetLengths(), b0_g_k_c_xs_lengths); + copy(wei0_g_k_c_xs_desc.GetStrides(), b0_g_k_c_xs_strides); + copy(wei1_g_k_c_xs_desc.GetLengths(), b1_g_k_c_xs_lengths); + copy(wei1_g_k_c_xs_desc.GetStrides(), b1_g_k_c_xs_strides); + copy(out1_g_n_k_wos_desc.GetLengths(), e1_g_n_k_wos_lengths); + copy(out1_g_n_k_wos_desc.GetStrides(), e1_g_n_k_wos_strides); + copy(conv0_param.conv_filter_strides_, conv0_filter_strides); + copy(conv0_param.conv_filter_dilations_, conv0_filter_dilations); + copy(conv0_param.input_left_pads_, input0_left_pads); + copy(conv0_param.input_right_pads_, input0_right_pads); + copy(conv1_param.conv_filter_strides_, conv1_filter_strides); + copy(conv1_param.conv_filter_dilations_, conv1_filter_dilations); + copy(conv1_param.input_left_pads_, input1_left_pads); + copy(conv1_param.input_right_pads_, input1_right_pads); + +#if 1 + // do Conv using GEMM, only works for 1x1 conv for now + const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0]; + + const ck::index_t gemm0_m_length = + e1_g_n_k_wos_lengths[1] * std::accumulate(e1_g_n_k_wos_lengths.begin() + 3, + e1_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + ck::index_t{1}, + std::multiplies{}); + + const ck::index_t gemm0_n_length = b0_g_k_c_xs_lengths[1]; + + const ck::index_t gemm0_k_length = + std::accumulate(b0_g_k_c_xs_lengths.begin() + 2, + b0_g_k_c_xs_lengths.begin() + 2 + NDimSpatial + 1, + ck::index_t{1}, + std::multiplies{}); + + const ck::index_t gemm1_n_length = b1_g_k_c_xs_lengths[1]; + + // + const ck::index_t a0_stride = a0_g_n_c_wis_strides[2 + NDimSpatial]; + const ck::index_t b0_stride = b0_g_k_c_xs_strides[2 + NDimSpatial]; + const ck::index_t b1_stride = b1_g_k_c_xs_strides[2 + NDimSpatial]; + const ck::index_t e1_stride = e1_g_n_k_wos_strides[2 + NDimSpatial]; + + // + const ck::index_t a0_batch_stride = a0_g_n_c_wis_strides[0]; + const ck::index_t b0_batch_stride = b0_g_k_c_xs_strides[0]; + const ck::index_t b1_batch_stride = b1_g_k_c_xs_strides[0]; + const ck::index_t e1_batch_stride = e1_g_n_k_wos_strides[0]; + + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(static_cast(in0_device_buf.GetDeviceBuffer()), + static_cast(wei0_device_buf.GetDeviceBuffer()), + static_cast(wei1_device_buf.GetDeviceBuffer()), + static_cast(out1_device_buf.GetDeviceBuffer()), + gemm0_m_length, + gemm0_n_length, + gemm0_k_length, + gemm1_n_length, + gemm_batch, + a0_stride, + b0_stride, + b1_stride, + e1_stride, + a0_batch_stride, + b0_batch_stride, + b1_batch_stride, + e1_batch_stride, + in0_element_op, + wei0_element_op, + out0_element_op, + wei1_element_op, + out1_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv0_param.GetFlops() + conv1_param.GetFlops(); + std::size_t num_btype = conv0_param.template GetInputByte() + + conv0_param.template GetWeightByte() + + conv1_param.template GetWeightByte() + + conv1_param.template GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; +#endif + + if(do_verification) + { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + Tensor out0_host(out0_g_n_k_wos_desc); + + auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_conv0_invoker = ref_conv0.MakeInvoker(); + auto ref_conv1_invoker = ref_conv1.MakeInvoker(); + + auto ref_conv0_argument = ref_conv0.MakeArgument(in0, + wei0, + out0_host, + conv0_param.conv_filter_strides_, + conv0_param.conv_filter_dilations_, + conv0_param.input_left_pads_, + conv0_param.input_right_pads_, + in0_element_op, + wei0_element_op, + out0_element_op); + + auto ref_conv1_argument = ref_conv1.MakeArgument(out0_host, + wei1, + out1_host, + conv1_param.conv_filter_strides_, + conv1_param.conv_filter_dilations_, + conv1_param.input_left_pads_, + conv1_param.input_right_pads_, + out0_element_op, + wei1_element_op, + out1_element_op); + + ref_conv0_invoker.Run(ref_conv0_argument); + ref_conv1_invoker.Run(ref_conv1_argument); + + out1_device_buf.FromDevice(out1_device.mData.data()); + + return ck::utils::check_err( + out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp new file mode 100644 index 0000000000..1a8a6817f2 --- /dev/null +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "grouped_conv_conv_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using In0DataType = ck::half_t; +using Wei0DataType = ck::half_t; +using Acc0DataType = float; +using Wei1DataType = ck::half_t; +using Acc1DataType = float; +using C1ShuffleDataType = float; +using Out1DataType = ck::half_t; + +template +using S = ck::Sequence; + +using In0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceBatchedGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + Row, // ALayout + Col, // B0Layout + Col, // B1Layout + Row, // CLayout + In0DataType, // ADataType, + Wei0DataType, // B0DataType, + Wei1DataType, // B1DataType, + Out1DataType, // CDataType, + Acc0DataType, // AccDataType, + C1ShuffleDataType, // CShuffleDataType, + In0ElementOp, // AElementOp, + Wei0ElementOp, // B0ElementOp, + Out0ElementOp, // Acc0ElementOp, + Wei1ElementOp, // B1ElementOp, + Out1ElementOp, // CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 4, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B1BlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + true, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv0_param{ + 2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + ck::utils::conv::ConvParam conv1_param{ + 2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + const auto in0_element_op = In0ElementOp{}; + const auto wei0_element_op = Wei0ElementOp{}; + const auto wei1_element_op = Wei1ElementOp{}; + const auto out0_element_op = Out0ElementOp{}; + const auto out1_element_op = Out1ElementOp{}; + + const auto run = [&](auto ndim_spatial, + auto in0_layout, + auto wei0_layout, + auto wei1_layout, + auto out1_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using In0Layout = decltype(in0_layout); + using Wei0Layout = decltype(wei0_layout); + using Wei1Layout = decltype(wei1_layout); + using Out1Layout = decltype(out1_layout); + + const auto in0_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv0_param); + + const auto wei0_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv0_param); + + // out0 doesn't physical exist, any layout for host verification is OK + const auto out0_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv0_param); + + const auto wei1_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv1_param); + + const auto out1_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv1_param); + + return run_grouped_conv_conv_fwd(do_verification, + init_method, + time_kernel, + conv0_param, + conv1_param, + in0_g_n_c_wis_desc, + wei0_g_k_c_xs_desc, + out0_g_n_k_wos_desc, + wei1_g_k_c_xs_desc, + out1_g_n_k_wos_desc, + in0_element_op, + wei0_element_op, + wei1_element_op, + out0_element_op, + out1_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv0_param.num_dim_spatial_ == 1) + { + run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{}); + } + else if(conv0_param.num_dim_spatial_ == 2) + { + run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{}); + } + else if(conv0_param.num_dim_spatial_ == 3) + { + run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 4324c92e10..9b1ba1a554 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -50,3 +50,4 @@ add_subdirectory(32_batched_gemm_scale_softmax_gemm) add_subdirectory(33_multiple_reduce) add_subdirectory(34_batchnorm) add_subdirectory(35_splitK_gemm) +add_subdirectory(41_grouped_conv_conv_fwd) diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp index 2146ca4562..9346c9b826 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" namespace ck { namespace tensor_operation { @@ -464,6 +465,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; - // Argument struct Argument : public BaseArgument { @@ -391,7 +389,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +#include #include "ck/tensor_operation/gpu/device/device_base.hpp" diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 936ac25d09..2e22aee225 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -13,8 +13,9 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" @@ -296,922 +297,71 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - template , - bool>::type = false> - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Wi = a_g_n_c_wis_lengths[3]; - - const index_t Wo = e_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto in_gemmmraw_gemmk_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - } - - template , - bool>::type = false> - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = e_g_n_k_wos_lengths[3]; - const index_t Wo = e_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto in_gemmmraw_gemmkraw_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_ho_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - } - - template , - bool>::type = false> - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = e_g_n_k_wos_lengths[3]; - const index_t Ho = e_g_n_k_wos_lengths[4]; - const index_t Wo = e_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NDoHoWo = - N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto in_gemmmraw_gemmkraw_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - } - - // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as - // properties - template || - is_same_v), - bool>::type = false> + template static auto MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, + const std::array& b_g_k_c_xs_strides, const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads) { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); - const index_t Wi = a_g_n_c_wis_lengths[3]; + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - const index_t Wo = e_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmmraw_gemmk_grid_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } + return in_gemmm_gemmk_desc; } - template || - is_same_v), - bool>::type = false> - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = e_g_n_k_wos_lengths[3]; - const index_t Wo = e_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmmraw_gemmkraw_grid_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_ho_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmmraw_gemmk_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - } - - template || - is_same_v), - bool>::type = false> - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = e_g_n_k_wos_lengths[3]; - const index_t Ho = e_g_n_k_wos_lengths[4]; - const index_t Wo = e_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NDoHoWo = - N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmmraw_gemmkraw_grid_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - else - { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmk_grid_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc); - - return in_gemmm_gemmk_grid_desc; - } - } - - template || - is_same_v || - is_same_v, - bool>::type = false> - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */) - { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, - b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto wei_k_yxc_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); - - const auto wei_gemmn_gemmk_grid_desc = - matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc); - - return wei_gemmn_gemmk_grid_desc; - } - - template || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + template static auto MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides) { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); - const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, - b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - const index_t KStride = b_g_k_c_xs_strides[1]; - const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto wei_k_yx_c_grid_desc = make_naive_tensor_descriptor( - make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); - - const auto wei_gemmnraw_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_yx_c_grid_desc, - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto wei_gemmn_gemmk_grid_desc = - matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_grid_desc); - - return wei_gemmn_gemmk_grid_desc; + return wei_gemmn_gemmk_desc; } - template || - is_same_v || - is_same_v, - bool>::type = false> - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& /* e_g_n_k_wos_strides */) - { - const index_t N = e_g_n_k_wos_lengths[1]; - const index_t K = e_g_n_k_wos_lengths[2]; - - const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto out_gemmmraw_gemmnraw_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); - - const auto out_gemmm_gemmn_grid_desc = - matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc); - - return out_gemmm_gemmn_grid_desc; - } - - template || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides) { - const index_t N = e_g_n_k_wos_lengths[1]; - const index_t K = e_g_n_k_wos_lengths[2]; + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); - const auto KStride = I1; - const index_t WoStride = e_g_n_k_wos_strides[NDimSpatial + 2]; + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); - const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3, - e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, - index_t{1}, - std::multiplies()); - - const auto out_gemmmraw_gemmnraw_grid_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); - - const auto out_gemmm_gemmn_grid_desc = - matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc); - - return out_gemmm_gemmn_grid_desc; + return out_gemmm_gemmn_desc; } static auto MakeDsGridDescriptor_M_N( diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 9da1297fc3..a872dd5bd4 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -12,70 +12,45 @@ namespace ck { namespace tensor_operation { namespace device { -// For padding tensors without batch dimension -template = false> +template + typename DoPads> // Sequence __host__ __device__ constexpr auto -PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw, - MPerBlockType MPerBlock, - NPerBlockType NPerBlock) +PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads) { - const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{}); - const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{}); + constexpr index_t num_dim = DoPads::Size(); - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(), + "wrong! inconsistent # of dimensions"); - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto MRaw = desc.GetLength(idim); - const auto MTransform = conditional_expr(make_right_pad_transform(MRaw, MPad), - make_pass_through_transform(MRaw)); - const auto NTransform = conditional_expr(make_right_pad_transform(NRaw, NPad), - make_pass_through_transform(NRaw)); + const auto MPerTile = tile_lengths[idim]; - return transform_tensor_descriptor(tensor_desc_mraw_nraw, - make_tuple(MTransform, NTransform), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); -} + const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile; -// For padding tensors with batch dimension -template = false> -__host__ __device__ constexpr auto -PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw, - MPerBlockType MPerBlock, - NPerBlockType NPerBlock) -{ - const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{}); - const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{}); - const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{}); + const auto MPad = M - MRaw; - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const bool DoPadM = DoPads::At(idim); - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; + const auto MTransform = conditional_expr(make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(MRaw)); - const auto MTransform = conditional_expr(make_right_pad_transform(MRaw, MPad), - make_pass_through_transform(MRaw)); - const auto NTransform = conditional_expr(make_right_pad_transform(NRaw, NPad), - make_pass_through_transform(NRaw)); + return MTransform; + }, + Number{}); - return transform_tensor_descriptor( - tensor_desc_graw_mraw_nraw, - make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss); } // M/N/K/OPerTileType could be index_t or Number<> @@ -113,7 +88,8 @@ struct GemmGemmPadder __host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const { - return PadTensorDescriptor(a_desc_mraw_kraw, MPerTile_, KPerTile_); + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); } // B[K, N] @@ -121,7 +97,8 @@ struct GemmGemmPadder __host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const { - return PadTensorDescriptor(b_desc_nraw_kraw, NPerTile_, KPerTile_); + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); } // B1[Gemm1N, Gemm1K] = B1[O, N] @@ -129,7 +106,8 @@ struct GemmGemmPadder __host__ __device__ constexpr auto PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const { - return PadTensorDescriptor(b1_desc_nraw_kraw, OPerTile_, NPerTile_); + return PadTensorDescriptor( + b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence{}); } // C[M, Gemm1N] = C[M, O] @@ -137,7 +115,8 @@ struct GemmGemmPadder __host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const { - return PadTensorDescriptor(c_desc_mraw_nraw, MPerTile_, OPerTile_); + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence{}); } MPerTileType MPerTile_; @@ -167,21 +146,24 @@ struct GemmPadder __host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const { - return PadTensorDescriptor(a_desc_mraw_kraw, MPerTile_, KPerTile_); + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); } template __host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const { - return PadTensorDescriptor(b_desc_nraw_kraw, NPerTile_, KPerTile_); + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); } template __host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const { - return PadTensorDescriptor(c_desc_mraw_nraw, MPerTile_, NPerTile_); + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); } MPerTileType MPerTile_; @@ -198,6 +180,44 @@ struct MatrixPadder : public GemmPadder +template +struct GemmPadder_v2 +{ + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp new file mode 100644 index 0000000000..37a6e362c4 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -0,0 +1,870 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvFwdToGemm +{ + static constexpr auto I1 = Number<1>{}; + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template , + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& /* a_g_n_c_wis_strides */, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + const auto in_n_di_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const auto in_n_di_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, + b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto wei_gemmn_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + + return wei_gemmn_gemmk_desc; + } + + template < + typename BLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3, + b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const index_t KStride = b_g_k_c_xs_strides[1]; + const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + + return out_gemmm_gemmn_desc; + } + + template < + typename CLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const auto KStride = I1; + const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3, + c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial, + index_t{1}, + std::multiplies()); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); + + return out_gemmm_gemmn_desc; + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/utility/convolution_parameter.hpp b/library/include/ck/library/utility/convolution_parameter.hpp index 5f37e03e15..1c80e392fd 100644 --- a/library/include/ck/library/utility/convolution_parameter.hpp +++ b/library/include/ck/library/utility/convolution_parameter.hpp @@ -49,30 +49,47 @@ struct ConvParam std::size_t GetFlops() const; - template - std::size_t GetByte() const + template + std::size_t GetInputByte() const { // sizeof(InDataType) * (G * N * C * ) + - // sizeof(WeiDataType) * (G * K * C * ) + - // sizeof(OutDataType) * (G * N * K * ); return sizeof(InDataType) * - (G_ * N_ * C_ * - std::accumulate(std::begin(input_spatial_lengths_), - std::begin(input_spatial_lengths_) + num_dim_spatial_, - static_cast(1), - std::multiplies())) + - sizeof(WeiDataType) * - (G_ * K_ * C_ * - std::accumulate(std::begin(filter_spatial_lengths_), - std::begin(filter_spatial_lengths_) + num_dim_spatial_, - static_cast(1), - std::multiplies())) + - sizeof(OutDataType) * (G_ * N_ * K_ * + (G_ * N_ * C_ * + std::accumulate(std::begin(input_spatial_lengths_), + std::begin(input_spatial_lengths_) + num_dim_spatial_, + static_cast(1), + std::multiplies())); + } + + template + std::size_t GetWeightByte() const + { + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * + (G_ * K_ * C_ * + std::accumulate(std::begin(filter_spatial_lengths_), + std::begin(filter_spatial_lengths_) + num_dim_spatial_, + static_cast(1), + std::multiplies())); + } + + template + std::size_t GetOutputByte() const + { + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G_ * N_ * K_ * std::accumulate(std::begin(output_spatial_lengths_), std::end(output_spatial_lengths_), static_cast(1), std::multiplies())); } + + template + std::size_t GetByte() const + { + return GetInputByte() + GetWeightByte() + + GetOutputByte(); + } }; std::string get_conv_param_parser_helper_msg();