From 3e8a6dfb9c3a2723281d5d02fb144ace56c6b49e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 20 Aug 2025 14:29:57 +0200 Subject: [PATCH] [CK Tile] Grouped convolution backward data (#2652) * base working version for single groupped conv bwd data * Fix 2d descriptor * fix groups * Add 3d support * fixes * fixes * fixes --------- Co-authored-by: Jakub Piasecki [ROCm/composable_kernel commit: 4212bbc170948292dc826c0f79aebea87b56d3f9] --- .../20_grouped_convolution/CMakeLists.txt | 3 + .../grouped_convolution_backward_data.cpp | 216 ++++ ...n_grouped_convolution_bwd_data_example.inc | 188 +++ include/ck_tile/core/tensor/tensor_view.hpp | 1 + include/ck_tile/host.hpp | 1 + .../reference_grouped_conv_bwd_data.hpp | 227 ++++ include/ck_tile/ops/grouped_convolution.hpp | 2 + ...ouped_convolution_backward_data_kernel.hpp | 985 +++++++++++++++ ...ped_convolution_backward_weight_kernel.hpp | 85 +- .../grouped_convolution_forward_kernel.hpp | 84 +- .../utils/grouped_convolution_utils.hpp | 1 + .../utils/transform_conv_bwd_data_to_gemm.hpp | 1064 +++++++++++++++++ 12 files changed, 2771 insertions(+), 86 deletions(-) create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp create mode 100644 example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc create mode 100644 include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp create mode 100644 include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt index c05dcac09c..5cb1d2650e 100644 --- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt +++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -6,3 +6,6 @@ target_compile_options(tile_example_grouped_conv_fwd 
PRIVATE ${EXAMPLE_GEMM_COMP add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp) target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + +add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp) +target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp new file mode 100644 index 0000000000..308961de5a --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = 
ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +#include "run_grouped_convolution_bwd_data_example.inc" + +template +int run_grouped_conv_bwd_data_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = 
ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} + +int run_grouped_conv_bwd_data_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_bwd_data_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_grouped_conv_bwd_data_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else + { + throw 
std::runtime_error("Unsupported data type for this operation!"); + } +} + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc new file mode 100644 index 0000000000..3e1c13c833 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, + int n_warmup, + int n_repeat) +{ + float ave_time = grouped_conv_bwd_data( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_bwd_data_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, 
+ strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(weight); + ck_tile::FillUniformDistribution{-1.f, 1.f}(output); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(weight); + ck_tile::FillMonotonicSeq{}(output); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(weight); + ck_tile::FillUniformDistribution{1.f, 1.f}(output); + } + else + { + weight.SetZero(); + output.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.SetZero(); + weight_dev_buf.ToDevice(weight.data()); + output_dev_buf.ToDevice(output.data()); + + ck_tile::GroupedConvBwdDataHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + 
invoke_grouped_conv_bwd_data(args, n_warmup, n_repeat); + + input_dev_buf.FromDevice(input.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor input_host_ref(in_g_n_c_wis_desc); + input_host_ref.SetZero(); + + ck_tile:: + reference_grouped_conv_bwd_data( + input_host_ref, + weight, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = + weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + *std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(input, + input_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 269465fae6..a85dbc6d00 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -445,6 +445,7 @@ struct null_tensor_view }; template diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index aa5afd25e5..41f5200413 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp" #include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" #include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp new file mode 100644 index 0000000000..c8264800c9 --- /dev/null +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor& input, + const HostTensor& weight, + const HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + + printf("%lu %lu %lu", + input.get_num_of_dimension(), + weight.get_num_of_dimension(), + output.get_num_of_dimension()); + + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto n, auto c, auto wi) { + std::size_t K = weight.get_lengths()[1]; + std::size_t X = weight.get_lengths()[3]; + + std::size_t Wo = output.get_lengths()[3]; + float v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast(in_left_pads[0]) - + static_cast(x * conv_dilations[0]); + + if(w_tmp % conv_strides[0] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(conv_strides[0]); + + if(wo >= 0 && ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = output(g, n, k, wo); + WeiDataType v_wei = weight(g, k, c, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto n, auto c, auto hi, auto wi) { + std::size_t K = weight.get_lengths()[1]; + std::size_t Y = weight.get_lengths()[3]; + std::size_t X = 
weight.get_lengths()[4]; + + std::size_t Ho = output.get_lengths()[3]; + std::size_t Wo = output.get_lengths()[4]; + + float v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = static_cast(hi) + + static_cast(in_left_pads[0]) - + static_cast(y * conv_dilations[0]); + if(h_tmp % conv_strides[0] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(conv_strides[0]); + if(ho >= 0 && ck_tile::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast(in_left_pads[1]) - + static_cast(x * conv_dilations[1]); + if(w_tmp % conv_strides[1] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(conv_strides[1]); + + if(wo >= 0 && ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = output(g, n, k, ho, wo); + WeiDataType v_wei = weight(g, k, c, y, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, hi, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3], + input.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) { + std::size_t K = weight.get_lengths()[1]; + std::size_t Z = weight.get_lengths()[3]; + std::size_t Y = weight.get_lengths()[4]; + std::size_t X = weight.get_lengths()[5]; + + std::size_t Do = output.get_lengths()[3]; + std::size_t Ho = output.get_lengths()[4]; + std::size_t Wo = output.get_lengths()[5]; + + float v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + auto d_tmp = static_cast(di) + + static_cast(in_left_pads[0]) - + static_cast(z * conv_dilations[0]); + if(d_tmp % conv_strides[0] == 0) + { + auto do_ = static_cast(d_tmp) / + static_cast(conv_strides[0]); + if(do_ >= 
0 && ck_tile::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = static_cast(hi) + + static_cast(in_left_pads[1]) - + static_cast(y * conv_dilations[1]); + if(h_tmp % conv_strides[1] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(conv_strides[1]); + if(ho >= 0 && ck_tile::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + static_cast(wi) + + static_cast(in_left_pads[2]) - + static_cast(x * + conv_dilations[2]); + + if(w_tmp % conv_strides[2] == 0) + { + auto wo = + static_cast(w_tmp) / + static_cast(conv_strides[2]); + if(wo >= 0 && + ck_tile::type_convert(wo) < Wo) + { + for(std::size_t k = 0; k < K; ++k) + { + OutDataType v_out = + output(g, n, k, do_, ho, wo); + WeiDataType v_wei = weight(g, k, c, z, y, x); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + } + } + } + } + } + InDataType v_acc_converted = ck_tile::type_convert(v_acc); + input(g, n, c, di, hi, wi) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + input.get_lengths()[0], + input.get_lengths()[1], + input.get_lengths()[2], + input.get_lengths()[3], + input.get_lengths()[4], + input.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error( + "Ref_conv_bwd_data: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 29332f941a..09b50f26b0 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -3,10 +3,12 @@ #pragma once +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" #include 
"ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp new file mode 100644 index 0000000000..282a187eae --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -0,0 +1,985 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" + +namespace ck_tile { + +/// @brief The Grouped Convolution kernel device arguments. 
+template +struct GroupedConvBwdDataKernelArgs +{ + using TilePartitioner = remove_cvref_t; + + using ConvToGemmTransformer = + TransformConvBwdDataToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t X = wei_g_k_c_xs_lengths[3]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, 
XTilde); + + if(XDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_xtilde}; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + 
static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde); + + if(XDotSlice * YDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_ytilde, i_xtilde}; + + ConvToGemmTransformer 
conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + } + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + 
out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + const index_t Z = wei_g_k_c_xs_lengths[3]; + const index_t Y = wei_g_k_c_xs_lengths[4]; + const index_t X = wei_g_k_c_xs_lengths[5]; + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; 
++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde); + + if(ZDotSlice * XDotSlice * YDotSlice <= 0) + { + continue; + } + + if(gemm_count >= MaxGroupedGemmGroupsNum) + { + gemm_count++; + // Avoid array segfault + continue; + } + + tildes = {i_ztilde, i_ytilde, i_xtilde}; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + auto grid_descs = conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType_::NDimSpatial>(1); + + a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{}); + b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{}); + c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{}); + + const index_t grid_size_grp = + TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0), + c_grid_descs_m_n[gemm_count].get_length(I1)); + + block_starts[gemm_count] = grid_size_; + block_ends[gemm_count] = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + ++gemm_count; + } + } + } + + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); // B: Wei GKXC + group_stride_c = args.C_; // C: In NWGC + + GemmBatch = args.G_; // C: In NWGC + } + + static constexpr index_t MaxGroupedGemmGroupsNum = 128; + + using ABCGridDescs = + remove_cvref_t; + + using AGridDescMK = remove_cvref_t{}])>; + using BGridDescNK = remove_cvref_t{}])>; + using CGridDescMN = remove_cvref_t{}])>; + + static constexpr index_t NonSpatialDims = 3; + array in_g_n_c_wis_lengths; + 
array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; + + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; + array tildes; + + index_t k_batch; + index_t GemmBatch; + index_t grid_size_ = 0; + index_t gemm_count = 0; + + const void* out_ptr; + void* in_ptr; + std::array ds_ptr; + const void* wei_ptr; + + array a_grid_descs_m_k; + array b_grid_descs_n_k; + array c_grid_descs_m_n; + + array block_starts; + array block_ends; + + long_index_t group_stride_a; + long_index_t group_stride_b; + long_index_t group_stride_c; +}; + +/// @brief The Grouped Convolution Backward Data kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the grouped convolution backward data kernel template. By +/// semantic division of Implicit GEMM algorithm into following parts we achieve +/// flexible, versatile and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// "function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so-called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provide the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution.
+/// +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into +/// the +/// output data tile to be calculated. It determines the +/// workgroup to data relationship (or in other words - which +/// data would be processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of +/// data loading from global memory and performing block-wise +/// matrix multiplication. You can think of it as a work done by +/// single workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output C tensor in global memory. +template +struct GroupedConvolutionBackwardDataKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + using OutDataType = remove_cvref_t; + + using GroupedConvBwdDataKernelArgsSpecialized = + 
GroupedConvBwdDataKernelArgs; + static constexpr index_t MaxGroupedGemmGroupsNum = + GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = false; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported A GEMM layout!"); + static_assert(std::is_same_v, + "Not supported B GEMM layout!"); + static_assert(std::is_same_v, + "Not supported C GEMM layout!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_backward_data", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs) + { + // enable batched grouped gemm + return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized + MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs) + { + return GroupedConvBwdDataKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_HOST static bool + IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + if(kargs.gemm_count > MaxGroupedGemmGroupsNum) 
+ { + return false; + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported input layout!"); + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % EpiloguePipeline::GetVectorSizeC() 
!= 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported weight layout!"); + return false; + } + + // check vector access of E + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view( + a_ptr, + kargs.a_grid_descs_m_k[group_id]); // A: out + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view( + b_ptr, + kargs.b_grid_descs_n_k[group_id]); // B: weight + }(); + + const auto& c_tensor_view = [&]() { + return make_tensor_view(c_ptr, + kargs.c_grid_descs_m_n[group_id]); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return 
pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, + const index_t i_m, + const index_t i_n, + const index_t i_k = 0) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_k}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, i_k}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. 
+ * @param kargs Grouped Convolution Backward Data kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t group_id) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum( + gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1))); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block.
+ * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs Grouped Convolution Backward Data kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t group_id) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1))); + + // Run GEMM cooperatively by whole workgroup. 
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, + index_t block_id) const + { + index_t left = 0; + index_t right = kargs.gemm_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= kargs.block_starts[group_id] && + block_id < kargs.block_ends[group_id])) && + left <= right) + { + if(block_id < kargs.block_starts[group_id]) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t group_id = FindGroupId(kargs, blockIdX); + + const auto [iM, iN] = OffsettedTile1DPartitioner::GetOffsetedTileIndex( + kargs.block_starts[group_id], + kargs.c_grid_descs_m_n[group_id].get_length(I0), + kargs.c_grid_descs_m_n[group_id].get_length(I1)); + + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // 
options + // conv_bwd_data = Out * Weight = In + const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; + InDataType* c_ptr = static_cast(kargs.in_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + i_m, + i_n, + group_id); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7ea2e31706..2700353049 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -17,19 +17,19 @@ namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. 
-template +template struct GroupedConvBwdWeightKernelArgs { using ConvToGemmTransformer = - TransformConvBwdWeightToGemm; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + TransformConvBwdWeightToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -75,7 +75,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -96,9 +96,9 @@ struct GroupedConvBwdWeightKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -151,7 +151,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -172,9 +172,9 @@ struct 
GroupedConvBwdWeightKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -234,7 +234,7 @@ struct GroupedConvBwdWeightKernelArgs // tuple auto grid_descs = conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< - GroupedConvTraitsType::NDimSpatial>(); + GroupedConvTraitsType_::NDimSpatial>(); a_grid_desc_m_k = grid_descs.at(number<0>{}); b_grid_desc_n_k = grid_descs.at(number<1>{}); @@ -263,14 +263,14 @@ struct GroupedConvBwdWeightKernelArgs using CGridDescMN = remove_cvref_t{}])>; static constexpr index_t NonSpatialDims = 3; - array in_g_n_c_wis_lengths; - array wei_g_k_c_xs_lengths; - array out_g_n_k_wos_lengths; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; - array conv_filter_strides; - array conv_filter_dilations; - array input_left_pads; - array input_right_pads; + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; index_t k_batch; index_t GemmM; @@ -292,12 +292,12 @@ struct GroupedConvBwdWeightKernelArgs long_index_t group_stride_c; }; -/// @brief The Grouped Convolution Forward kernel template. +/// @brief The Grouped Convolution Backward Weight kernel template. /// /// @paragraph Overview Overview -/// This class provides the grouped convolution forward kernel template. By semantic -/// division of Implicit GEMM algorithm into following parts we achieve flexible, -/// versatile and robust kernel implementation. +/// This class provides the grouped convolution backward weight kernel template. 
By +/// semantic division of Implicit GEMM algorithm into following parts we achieve +/// flexible, versatile and robust kernel implementation. /// /// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() /// function call operator" which determines the work scope of each workgroup. @@ -315,7 +315,7 @@ struct GroupedConvBwdWeightKernelArgs /// the policy is responsible for definition of all necessary data layouts and thread's /// work distribution. /// -/// tparam ConvSpecialization Tensor descriptors specialization. +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. /// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into /// the /// output data tile to be calculated. It determines the @@ -330,15 +330,15 @@ struct GroupedConvBwdWeightKernelArgs /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to /// the output C tensor in global memory. 
-template struct GroupedConvolutionBackwardWeightKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial_; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = - GroupedConvTraitsType::ConvSpecialization; + GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; @@ -346,13 +346,13 @@ struct GroupedConvolutionBackwardWeightKernel using GemmBLayout = remove_cvref_t; using GemmCLayout = remove_cvref_t; - using InLayout = remove_cvref_t; - using WeiLayout = remove_cvref_t; - using OutLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; using GemmDsLayout = remove_cvref_t; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; @@ -363,7 +363,7 @@ struct GroupedConvolutionBackwardWeightKernel using OutDataType = remove_cvref_t; using GroupedConvBwdWeightKernelArgsSpecialized = - GroupedConvBwdWeightKernelArgs; + GroupedConvBwdWeightKernelArgs; // TODO: Enable this static constexpr bool IsSplitKSupported = true; @@ -594,12 +594,9 @@ struct GroupedConvolutionBackwardWeightKernel }(); const auto& c_tensor_view = [&]() { - return make_naive_tensor_view( + return make_tensor_view( c_ptr, - make_tuple(kargs.GemmM, kargs.GemmN), - make_tuple(kargs.GemmN, 1), - number{}, - number<1>{}); + kargs.c_grid_desc_m_n); // B: in }(); const auto& ds_tensor_view = generate_tuple( @@ -708,7 +705,7 @@ struct GroupedConvolutionBackwardWeightKernel * @param b_ptr input B pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared 
memory block. - * @param kargs Grouped Convolution Forward kernel arguments + * @param kargs Grouped Convolution Backward Weight kernel arguments * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -758,7 +755,7 @@ struct GroupedConvolutionBackwardWeightKernel * @param c_ptr output C pointer * @param smem_ptr_0 The starting pointer of 1st shared memory block. * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs Grouped Convolution Forward kernel arguments + * @param kargs Grouped Convolution Backward Weight kernel arguments * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index d3a90ea144..d4f4eca0d0 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -17,19 +17,19 @@ namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. 
-template +template struct GroupedConvFwdKernelArgs { using ConvToGemmFwdTransformer = - TransformConvFwdToGemm; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + TransformConvFwdToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -79,13 +79,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -97,9 +97,9 @@ struct GroupedConvFwdKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -156,13 +156,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = 
conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -174,9 +174,9 @@ struct GroupedConvFwdKernelArgs } template < - typename InLay = typename GroupedConvTraitsType::InLayout, - typename WeiLay = typename GroupedConvTraitsType::WeiLayout, - typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename InLay = typename GroupedConvTraitsType_::InLayout, + typename WeiLay = typename GroupedConvTraitsType_::WeiLayout, + typename OutLay = typename GroupedConvTraitsType_::OutLayout, typename std::enable_if && std::is_same_v && std::is_same_v, @@ -242,13 +242,13 @@ struct GroupedConvFwdKernelArgs a_grid_desc_m_k = conv_to_gemm_transformer - .template MakeADescriptor_M_K(); + .template MakeADescriptor_M_K(); b_grid_desc_n_k = conv_to_gemm_transformer - .template MakeBDescriptor_N_K(); + .template MakeBDescriptor_N_K(); c_grid_desc_m_n = conv_to_gemm_transformer - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); group_stride_a = args.C_; group_stride_b = args.K_ * args.C_ * @@ -261,23 +261,23 @@ struct GroupedConvFwdKernelArgs using AGridDescMK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeADescriptor_M_K())>; + .template MakeADescriptor_M_K())>; using BGridDescNK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeBDescriptor_N_K())>; + .template MakeBDescriptor_N_K())>; using CGridDescMN = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} - .template MakeCDescriptor_M_N())>; + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; - array in_g_n_c_wis_lengths; - array wei_g_k_c_xs_lengths; - array out_g_n_k_wos_lengths; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; - array conv_filter_strides; - array conv_filter_dilations; - array input_left_pads; - array input_right_pads; + array conv_filter_strides; + array 
conv_filter_dilations; + array input_left_pads; + array input_right_pads; index_t k_batch; index_t GemmM; @@ -322,7 +322,7 @@ struct GroupedConvFwdKernelArgs /// the policy is responsible for definition of all necessary data layouts and thread's /// work distribution. /// -/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution. +/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution. /// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into /// the /// output data tile to be calculated. It determines the @@ -337,15 +337,15 @@ struct GroupedConvFwdKernelArgs /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to /// the output C tensor in global memory. -template struct GroupedConvolutionForwardKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = - GroupedConvTraitsType::ConvSpecialization; + GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; @@ -353,13 +353,13 @@ struct GroupedConvolutionForwardKernel using GemmBLayout = remove_cvref_t; using GemmCLayout = remove_cvref_t; - using InLayout = remove_cvref_t; - using WeiLayout = remove_cvref_t; - using OutLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; using GemmDsLayout = remove_cvref_t; - static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; @@ -369,7 
+369,7 @@ struct GroupedConvolutionForwardKernel // Below type is actually accumulation data type - the output of block GEMM. using OutDataType = remove_cvref_t; - using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; + using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; // TODO: Enable this static constexpr bool IsSplitKSupported = false; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index b173ab25a1..3e5e87a975 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -42,6 +42,7 @@ struct GroupedConvHostArgs : public conv::ConvParam using GroupedConvFwdHostArgs = GroupedConvHostArgs; using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs; +using GroupedConvBwdDataHostArgs = GroupedConvHostArgs; template +struct TransformConvBwdDataToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const 
long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. 
+ return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvBwdDataToGemm() {} + + template + CK_TILE_HOST + TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base) + : G_{static_cast(transform_conv_to_gemm_base.G_)}, + N_{static_cast(transform_conv_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_to_gemm_base.X_)}, + K_{static_cast(transform_conv_to_gemm_base.K_)}, + C_{static_cast(transform_conv_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_to_gemm_base.InRightPadW_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& 
conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + IdxZTilde_{I1}, + IdxYTilde_{I1}, + IdxXTilde_{tildes[I0]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + XDot_ = integer_divide_ceil(X_, XTilde_); + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + 
Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + IdxZTilde_{I1}, + IdxYTilde_{tildes[I0]}, + IdxXTilde_{tildes[I1]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + HTilde_ = Ho_ + integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + XDot_ = integer_divide_ceil(X_, XTilde_); + YDot_ = integer_divide_ceil(Y_, YTilde_); + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdDataToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + [[maybe_unused]] const ConvSpatialDimsType& tildes) + : G_{a_g_n_c_wis_lengths[I0]}, + N_{a_g_n_c_wis_lengths[I1]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + 
Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + IdxZTilde_{tildes[I0]}, + IdxYTilde_{tildes[I1]}, + IdxXTilde_{tildes[I2]} + { +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_); + GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_); + GcdStrideDilationD_ = gcd(ConvStrideD_, ConvDilationD_); + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + ZTilde_ = ConvStrideD_ / GcdStrideDilationD_; + WTilde_ = Wo_ + integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + HTilde_ = Ho_ + integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + DTilde_ = Do_ + integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_); + XDot_ = integer_divide_ceil(X_, XTilde_); + YDot_ = integer_divide_ceil(Y_, YTilde_); + ZDot_ = integer_divide_ceil(Z_, ZTilde_); + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * 
WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); 
+ const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + 
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NWGK + const index_t NStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), + make_tuple(NStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, X_, C_)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NWGC + const index_t NStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; // GC? 
+ constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NHWGK + const index_t NStride = Ho_ * Wo_ * G_ * K_; + const index_t HoStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStride, HoStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NHWGC + const index_t NStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKYXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NDHWGK + const index_t NStride = Do_ * Ho_ * Wo_ * G_ * K_; + const index_t DoStride = Ho_ * Wo_ * G_ * K_; + const index_t HoStride = Wo_ * G_ * K_; + const index_t WoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_; + const index_t DiStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO 
Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKZYXC + return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + } + // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimension as + // properties + + template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // A: output tensor comes in K_M + const auto out_n_wop_k_grid_desc = + transform_tensor_descriptor(out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), +
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_merge_transform(make_tuple(N_, WTildeSlice))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = + transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1>{}, 
sequence<0>{})); + + // c: input + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), 
ConvStrideH_); + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IHTildeSliceEnd = + min(HTilde_, integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + 
make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<2>{}, + sequence<4>{}, + sequence<5>{}), + 
make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<>{}, + sequence<>{}, + sequence<3>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + 
sequence<>{}, + sequence<1>{}, + sequence<>{}, + sequence<2>{}, + sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + /* Variant using three spatial dimensions (D, H, W). Returns a tuple of three descriptors built with transform_tensor_descriptor: the output-tensor view, the weight view and the input view, in that order. IdxZTilde_, IdxYTilde_ and IdxXTilde_ are frozen inside, so one call describes one of the per-index GEMMs mentioned below. GemmKBatch is marked maybe_unused and is not read in this body. NOTE(review): the template parameter list on this line looks truncated in this copy of the patch; confirm against the upstream header. */ template ::type = false> + CK_TILE_HOST auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const + { + // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input + // tensor + /* Begin index of the kept DTilde, HTilde and WTilde ranges; passed as the slice begin further down. */ const auto IDTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); + const auto IHTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); + const auto IWTildeSliceBegin = integer_divide_floor( + max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + /* End index of the same ranges, capped by min() at DTilde_, HTilde_ and WTilde_. */ const auto IDTildeSliceEnd = + min(DTilde_, integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); + const auto IHTildeSliceEnd = + min(HTilde_, integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = + min(WTilde_, integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + /* Range lengths (end minus begin); reused below as slice arguments and merge lengths. */ const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + /* Dot-slice lengths are computed from Z_, Y_, X_ and the frozen Tilde indices. */ const auto ZDotSlice = integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + const auto out_grid_desc = 
make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); /* The three base descriptors come from helpers defined outside this excerpt. */ + + // A: output tensor comes in K_M + /* Step A1: N_ and K_ pass through; Do_, Ho_ and Wo_ go through pad transforms whose pad arguments are both I0. NOTE(review): the local names in this section (hop_wop, ydot_htilde_xdot_wtilde) mention only H and W although the descriptors here also carry the D and Z dimensions; they look carried over from the two-dimensional variant. */ const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + /* Step A2: each spatial dimension is embedded into a (Dot, Tilde) pair with coefficients (-ConvDilation / GcdStrideDilation, I1); five dimensions become eight. */ const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + /* Step A3: every Dot dimension is sliced with (I0, DotSlice) and every Tilde dimension with the begin and length computed at the top; the dimension order is unchanged. NOTE(review): the third argument passed to make_slice_transform for the Tilde dimensions is a length (end minus begin); confirm that the helper expects a length rather than an end index. */ const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, 
+ sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{})); + + /* Step A4: (ZDotSlice, YDotSlice, XDotSlice, K_) are merged into result dimension 1 and (N_, DTildeSlice, HTildeSlice, WTildeSlice) into result dimension 0. */ const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + /* Step B1: K_ and C_ pass through; the three middle dimensions are embedded into (Dot, Tilde) pairs with coefficients (ConvStride / GcdStrideDilation, I1); five dimensions become eight. */ const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + /* Step B2: the Dot dimensions (1, 3, 5) are sliced with (I0, DotSlice) and the Tilde dimensions (2, 4, 6) are frozen at IdxZTilde_, IdxYTilde_ and IdxXTilde_; the frozen ones map to empty sequences, so eight dimensions become five. */ const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<5>{}, + sequence<2>{}, + sequence<4>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, 
+ sequence<2>{}, + sequence<3>{}, + sequence<>{}, + sequence<>{}, + sequence<>{}, + sequence<4>{})); + + /* Step B3: (ZDotSlice, YDotSlice, XDotSlice, K_) are merged into result dimension 1; C_ passes through as result dimension 0. */ const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // c: input + /* Step C1: N_ and C_ pass through; Di_, Hi_ and Wi_ are padded with the corresponding InLeftPad and InRightPad members. */ const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + /* Step C2: each padded dimension is embedded into a pair of Tilde dimensions, (ZTilde_, DTilde_), (YTilde_, HTilde_) and (XTilde_, WTilde_), with coefficients (ConvDilation, ConvStride); five dimensions become eight. */ const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + /* Step C3: the ZTilde, YTilde and XTilde dimensions are frozen at IdxZTilde_, IdxYTilde_ and IdxXTilde_ (mapped to empty sequences) and the DTilde, HTilde and WTilde dimensions are sliced with the begin and length computed at the top; eight dimensions become five. */ const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + 
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<>{}, + sequence<1>{}, + sequence<>{}, + sequence<2>{}, + sequence<>{}, + sequence<3>{}, + sequence<4>{})); + + /* Step C4: (N_, DTildeSlice, HTildeSlice, WTildeSlice) are merged into result dimension 0; C_ passes through as result dimension 1. */ const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + /* Result order: output-tensor view, weight view, input view. */ return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } + + /* Stored parameters read by the descriptor builder above as pad, embed, slice, freeze and merge arguments. NOTE(review): their meanings are inferred from that usage and from their names only; G_ is not read anywhere in the visible code, so confirm its role against the rest of the class. */ IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_; + IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_; + IndexType ZTilde_, YTilde_, XTilde_; + IndexType DTilde_, HTilde_, WTilde_; + IndexType ZDot_, YDot_, XDot_; +}; + +} // namespace ck_tile