From 2567f5e538ded308979d6fd4803a4ba1345ac926 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?=
Date: Sat, 21 Jun 2025 00:44:36 +0200
Subject: [PATCH] [CK TILE] Grouped Convolution Forward Kernel (#2188)

* [CK TILE] Grouped Convolution Forward Kernel
* custom vector size
* fixes
* refactor
* rebase fixes
* fixes
* fixes

[ROCm/composable_kernel commit: cebdee4d9ee9533b9928a0c1a4c155f6693607c6]
---
 .../20_grouped_convolution/CMakeLists.txt     |    4 +
 .../grouped_convolution_forward.cpp           |  207 +++
 .../grouped_convolution_utils.hpp             |  108 ++
 .../run_grouped_convolution_example.inc       |  206 +++
 example/ck_tile/CMakeLists.txt                |    1 +
 include/ck_tile/host.hpp                      |    1 +
 .../reference/reference_grouped_conv_fwd.hpp  |  165 ++
 .../ops/epilogue/cshuffle_epilogue.hpp        |   12 +-
 .../gemm_pipeline_ag_bg_cr_comp_v5.hpp        |    3 +-
 ...ine_agmem_bgmem_creg_v1_default_policy.hpp |    4 +-
 .../gemm/pipeline/gemm_pipeline_problem.hpp   |   38 +-
 ...emm_universal_pipeline_ag_bg_cr_policy.hpp |   18 +-
 include/ck_tile/ops/grouped_convolution.hpp   |   12 +
 .../grouped_convolution_forward_kernel.hpp    |  800 +++++++++
 .../utils/convolution_specialization.hpp      |   30 +
 .../utils/grouped_convolution_utils.hpp       |   74 +
 .../utils/transform_conv_fwd_to_gemm.hpp      | 1432 +++++++++++++++++
 17 files changed, 3096 insertions(+), 19 deletions(-)
 create mode 100644 example/ck_tile/20_grouped_convolution/CMakeLists.txt
 create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp
 create mode 100644 example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp
 create mode 100644 example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc
 create mode 100644 include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
 create mode 100644 include/ck_tile/ops/grouped_convolution.hpp
 create mode 100644 include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
 create mode 100644 include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp
 create mode 100644 include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp
 create mode 100644 include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt
new file mode 100644
index 0000000000..00cb0ab9e5
--- /dev/null
+++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt
@@ -0,0 +1,4 @@
+add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
+set(EXAMPLE_CONV_COMPILE_OPTIONS)
+list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
+target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp
new file mode 100644
index 0000000000..685fdccde2
--- /dev/null
+++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp
@@ -0,0 +1,207 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
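+// This example implements a grouped forward convolution as an implicit GEMM:
+// the input image is read through an im2col-style descriptor as the GEMM A
+// tensor, the weights as B and the output image as C, with blockIdx.y
+// selecting the convolution group.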
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                      << "shape: " << CodegenShape::GetName() << '\n'
+                      << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                      << "pipeline: " << CodegenPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << '\n'
+                      << "Vector size A: " << CodegenPipeline::GetVectorSizeA()
+                      << ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
+                      << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
+        }
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    };
+
+    return Run(ck_tile::integral_constant{});
+}
+
+#include "run_grouped_convolution_example.inc"
+
+template
+int run_grouped_conv_fwd_example_prec_type(
+    std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
+{
+    using NWGC   = ck_tile::tensor_layout::convolution::NWGC;
+    using NHWGC  = ck_tile::tensor_layout::convolution::NHWGC;
+    using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
+
+    using GKXC   = ck_tile::tensor_layout::convolution::GKXC;
+    using GKYXC  = ck_tile::tensor_layout::convolution::GKYXC;
+    using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
+
+    using NWGK   = ck_tile::tensor_layout::convolution::NWGK;
+    using NHWGK  = ck_tile::tensor_layout::convolution::NHWGK;
+    using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
+
+    if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
+    {
+        return run_grouped_conv_fwd_example_with_layouts{},
+                                                         InPrecType,
+                                                         WeiPrecType,
+                                                         OutPrecType>(
+            argc, argv, NWGC{}, GKXC{}, NWGK{});
+    }
+    else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
+    {
+        return run_grouped_conv_fwd_example_with_layouts{},
+                                                         InPrecType,
+                                                         WeiPrecType,
+                                                         OutPrecType>(
+            argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
+    }
+    else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
+    {
+        return run_grouped_conv_fwd_example_with_layouts{},
+                                                         InPrecType,
+                                                         WeiPrecType,
+                                                         OutPrecType>(
+            argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported memory layout!");
+    }
+}
+
+int run_grouped_conv_fwd_example(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    std::string data_type  = arg_parser.get_str("prec");
+    std::string in_layout  = arg_parser.get_str("in_layout");
+    std::string wei_layout = arg_parser.get_str("weight_layout");
+    std::string out_layout = arg_parser.get_str("out_layout");
+
+    if(data_type == "fp16")
+    {
+        return run_grouped_conv_fwd_example_prec_type(
+            in_layout, wei_layout, out_layout, argc, argv);
+    }
+    else if(data_type == "bf16")
+    {
+        return run_grouped_conv_fwd_example_prec_type(
+            in_layout, wei_layout, out_layout, argc, argv);
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported data type for this operation!");
+    }
+}
+
+int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }
diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp
new file mode 100644
index 0000000000..cc8d365b18
--- /dev/null
+++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp
@@ -0,0 +1,108 @@
+
+// SPDX-License-Identifier: MIT
+// Copyright
(c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +ck_tile::index_t fill_spatial_dimensions(std::vector& filter_spatial_lengths, + std::vector& image_spatial_lengths, + std::vector& strides, + std::vector& dilations, + std::vector& lpads, + std::vector& rpads, + ck_tile::ArgParser& arg_parser) +{ + + constexpr ck_tile::index_t non_sp_dims = 3; + const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims; + + if(!(n_dim_sp >= 1 && n_dim_sp <= 3)) + { + throw std::runtime_error("Wrong layout!\n"); + } + + if(n_dim_sp == 3) + { + filter_spatial_lengths.push_back(arg_parser.get_int("z")); + image_spatial_lengths.push_back(arg_parser.get_int("d")); + strides.push_back(arg_parser.get_int("stride_d")); + dilations.push_back(arg_parser.get_int("dilation_d")); + lpads.push_back(arg_parser.get_int("lpad_d")); + rpads.push_back(arg_parser.get_int("rpad_d")); + } + if(n_dim_sp >= 2) + { + filter_spatial_lengths.push_back(arg_parser.get_int("y")); + image_spatial_lengths.push_back(arg_parser.get_int("h")); + strides.push_back(arg_parser.get_int("stride_h")); + dilations.push_back(arg_parser.get_int("dilation_h")); + lpads.push_back(arg_parser.get_int("lpad_h")); + rpads.push_back(arg_parser.get_int("rpad_h")); + } + filter_spatial_lengths.push_back(arg_parser.get_int("x")); + image_spatial_lengths.push_back(arg_parser.get_int("w")); + strides.push_back(arg_parser.get_int("stride_w")); + dilations.push_back(arg_parser.get_int("dilation_w")); + lpads.push_back(arg_parser.get_int("lpad_w")); + rpads.push_back(arg_parser.get_int("rpad_w")); + + return n_dim_sp; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("g", "2", "group dimension") + .insert("n", "32", "n dimension") + .insert("k", "32", "k dimension") + .insert("c", "32", "c dimension") + + .insert("d", "64", "d dimension") + .insert("h", "64", "h dimension") + .insert("w", "64", "w dimension") + + .insert("z", "4", "z dimension") + .insert("y", "4", "y dimension") + .insert("x", "4", "x dimension") + + .insert("stride_d", "1", "d stride") + .insert("stride_h", "1", "h stride") + .insert("stride_w", "1", "w stride") + + .insert("dilation_d", "1", "d dilation") + .insert("dilation_h", "1", "h dilation") + .insert("dilation_w", "1", "w dilation") + + .insert("lpad_d", "0", "left pad for d dimension") + .insert("lpad_h", "0", "left pad for h dimension") + .insert("lpad_w", "0", "left pad for w dimension") + + .insert("rpad_d", "0", "right pad for d dimension") + .insert("rpad_h", "0", "right pad for h dimension") + .insert("rpad_w", "0", "right pad for w dimension") + + .insert("in_layout", "NHWGC", "Input image layout - NHWGC by default") + .insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default") + .insert("out_layout", "NHWGK", "Output image layout - NHWGK by default") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc new file mode 100644 index 0000000000..ed72eb354d --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +auto calculate_rtol_atol(const ck_tile::index_t GemmK, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(GemmK, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = + ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat) +{ + float ave_time = grouped_conv_fwd( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_fwd_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + 
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(input); + ck_tile::FillUniformDistribution{-5.f, 5.f}(weight); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(input); + ck_tile::FillMonotonicSeq{}(weight); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(input); + ck_tile::FillUniformDistribution{1.f, 1.f}(weight); + } + else + { + input.SetZero(); + weight.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.ToDevice(input.data()); + weight_dev_buf.ToDevice(weight.data()); + output_dev_buf.SetZero(); + + ck_tile::GroupedConvHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Fwd kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + invoke_grouped_conv_fwd(args, n_warmup, n_repeat); + + output_dev_buf.FromDevice(output.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor output_host_ref(out_g_n_k_wos_desc); + output_host_ref.SetZero(); + + ck_tile::reference_grouped_conv_fwd( + input, + weight, + output_host_ref, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + *std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(output, + output_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 92b859a750..8989060842 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) +add_subdirectory(20_grouped_convolution) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) add_subdirectory(37_transpose) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 44851fec4a..4a9748fcbb 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_moe_sorting.hpp" diff --git a/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp new file mode 100644 index 0000000000..8a12fdb7e0 --- /dev/null +++ b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor& input, + const HostTensor& weight, + HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! 
inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto n, auto k, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x) + { + auto wi = static_cast(wo * conv_strides[0]) + + static_cast(x * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + if(wi >= 0 && ck_tile::type_convert(wi) < input.get_lengths()[3]) + { + InDataType v_in = input(g, n, c, wi); + WeiDataType v_wei = weight(g, k, c, x); + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y) + { + auto hi = static_cast(ho * conv_strides[0]) + + static_cast(y * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x) + { + auto wi = static_cast(wo * conv_strides[1]) + + static_cast(x * conv_dilations[1]) - + static_cast(in_left_pads[1]); + + if(hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[3] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[4]) + { + InDataType v_in = input(g, n, c, hi, wi); + WeiDataType v_wei = weight(g, k, c, y, x); + + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, ho, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3], + output.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z) + { + auto di = static_cast(d_o * conv_strides[0]) + + static_cast(z * conv_dilations[0]) - + static_cast(in_left_pads[0]); + for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y) + { + auto hi = static_cast(ho * conv_strides[1]) + + static_cast(y * conv_dilations[1]) - + static_cast(in_left_pads[1]); + for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x) + { + auto wi = static_cast(wo * conv_strides[2]) + + static_cast(x * conv_dilations[2]) - + static_cast(in_left_pads[2]); + if(di >= 0 && + ck_tile::type_convert(di) < input.get_lengths()[3] && + hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[4] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[5]) + { + InDataType v_in = input(g, n, c, di, hi, wi); + WeiDataType v_wei = weight(g, k, c, z, y, x); + + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, d_o, ho, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3], + 
output.get_lengths()[4], + output.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 68e91520bf..bf58544259 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -27,7 +27,9 @@ template + index_t kNumWaveGroups_ = 1, + bool FixedVectorSize_ = false, + index_t VectorSizeC_ = 1> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -48,6 +50,8 @@ struct CShuffleEpilogueProblem static constexpr index_t KPerXdl = KPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -80,6 +84,8 @@ struct CShuffleEpilogue static constexpr index_t NPerXdl = Problem::NPerXdl; static constexpr index_t KPerXdl = Problem::KPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; static constexpr index_t NumDTensor = Problem::NumDTensor; @@ -98,6 +104,10 @@ struct CShuffleEpilogue */ CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() { + if constexpr(FixedVectorSize) + { + return VectorSizeC; + } constexpr index_t max_vector_size = 16; if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 55220730cd..424565060b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -1,8 +1,7 @@ #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 6bb14af9e6..0f7f6369f0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M1 = Problem::VectorSizeA; constexpr index_t M0 = MPerBlock / M1; constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); @@ -211,7 +211,7 @@ struct 
GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N1 = Problem::VectorSizeB; constexpr index_t N0 = NPerBlock / N1; constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index b10ee0320f..dc7d150b46 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,7 +14,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -24,6 +27,8 @@ struct GemmPipelineProblemBase using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; @@ -115,7 +120,11 @@ struct GemmPipelineProblemBase } static constexpr index_t VectorSizeA = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeA_; + } + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -126,7 +135,11 @@ struct GemmPipelineProblemBase }(); static constexpr index_t VectorSizeB = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeB_; + } + else if constexpr(std::is_same_v) { return kPadN ? 1 : GetAlignmentB(); } @@ -153,13 +166,19 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> using GemmPipelineProblem = GemmPipelineProblemBase; + ComputeDataType_, + FixedVectorSize_, + VectorSizeA_, + VectorSizeB_>; template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -179,6 +201,10 @@ struct UniversalGemmPipelineProblem using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 91e845d200..d5f2eedf2d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeA : GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: MPerBlock X KPerBlock @@ -461,10 +462,11 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: KPerBlock X NPerBlock diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp new file mode 100644 index 0000000000..ae5720776c --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp new file mode 100644 index 0000000000..196c468c07 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -0,0 +1,800 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" + +namespace ck_tile { + +/// @brief The Grouped Convolution kernel device arguments. 
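+///
+/// @note Three SFINAE-selected constructors cover the 1D (NWGC/GKXC/NWGK),
+/// 2D (NHWGC/GKYXC/NHWGK) and 3D (NDHWGC/GKZYXC/NDHWGK) layouts; each packs the
+/// convolution lengths, strides, dilations and pads into fixed-size arrays and
+/// pre-builds the implicit GEMM descriptors through TransformConvFwdToGemm.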
+template +struct GroupedConvFwdKernelArgs +{ + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + a_grid_desc_m_k = + conv_to_gemm_transformer + .template MakeADescriptor_M_K(); + b_grid_desc_n_k = + conv_to_gemm_transformer + .template MakeBDescriptor_N_K(); + c_grid_desc_m_n = + conv_to_gemm_transformer + .template MakeCDescriptor_M_N(); + + group_stride_a = args.C_; + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + group_stride_c = args.K_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + 
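+        // Right-hand pads are stored separately from the left pads above so
+        // that asymmetric padding remains representable in the descriptor
+        // transform.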
input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + a_grid_desc_m_k = + conv_to_gemm_transformer + .template MakeADescriptor_M_K(); + b_grid_desc_n_k = + conv_to_gemm_transformer + .template MakeBDescriptor_N_K(); + c_grid_desc_m_n = + conv_to_gemm_transformer + .template MakeCDescriptor_M_N(); + + group_stride_a = args.C_; + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + group_stride_c = args.K_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * + args.output_spatial_lengths_[2]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * + args.filter_spatial_lengths_[2]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + 
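+        // The transformer below exposes the convolution as three GEMM views:
+        // A is an im2col-style view of the input (GemmM x GemmK), B is the
+        // filter (GemmN x GemmK) and C is the output image (GemmM x GemmN).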
a_grid_desc_m_k =
+            conv_to_gemm_transformer
+                .template MakeADescriptor_M_K();
+        b_grid_desc_n_k =
+            conv_to_gemm_transformer
+                .template MakeBDescriptor_N_K();
+        c_grid_desc_m_n =
+            conv_to_gemm_transformer
+                .template MakeCDescriptor_M_N();
+
+        group_stride_a = args.C_;
+        group_stride_b = args.K_ * args.C_ *
+                         std::accumulate(args.filter_spatial_lengths_.begin(),
+                                         args.filter_spatial_lengths_.end(),
+                                         1,
+                                         std::multiplies());
+        group_stride_c = args.K_;
+    }
+
+    using AGridDescMK = remove_cvref_t())>;
+    using BGridDescNK = remove_cvref_t())>;
+    using CGridDescMN = remove_cvref_t())>;
+
+    static constexpr index_t NonSpatialDims = 3;
+    array in_g_n_c_wis_lengths;
+    array wei_g_k_c_xs_lengths;
+    array out_g_n_k_wos_lengths;
+
+    array conv_filter_strides;
+    array conv_filter_dilations;
+    array input_left_pads;
+    array input_right_pads;
+
+    index_t k_batch;
+    index_t GemmM;
+    index_t GemmN;
+    index_t GemmK;
+
+    const void* in_ptr;
+    const void* wei_ptr;
+    std::array ds_ptr;
+    void* out_ptr;
+
+    AGridDescMK a_grid_desc_m_k;
+    BGridDescNK b_grid_desc_n_k;
+    CGridDescMN c_grid_desc_m_n;
+
+    long_index_t group_stride_a;
+    long_index_t group_stride_b;
+    long_index_t group_stride_c;
+};
+
+/// @brief The Grouped Convolution Forward kernel template.
+///
+/// @paragraph Overview Overview
+///            This class provides the grouped convolution forward kernel template. By
+///            semantically dividing the implicit GEMM algorithm into the following parts,
+///            we achieve a flexible, versatile and robust kernel implementation.
+///
+/// @li @b Prolog - The start of the GEMM kernel implementation in the @ref operator()
+///        "function call operator", which determines the work scope of each workgroup.
+/// @li @b GemmPipeline - The core part, the @a "heart", of the matrix multiplication
+///        algorithm. This is where each workgroup loads data from global memory and
+///        carries out the dot products.
+/// @li @b Epilogue - The @a "final" part of the matrix multiplication implementation,
+///        responsible for storing results to global memory. This is also the place where
+///        any additional operator fusion may take place.
+///
+/// Additionally, both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
+/// "EpiloguePipeline" are parameterized with a so-called @a Policy, which determines all
+/// internal details of those functional parts. You can think of the gemm and epilogue
+/// pipelines as providing the control-flow logic, while the policy is responsible for
+/// defining all necessary data layouts and the threads' work distribution.
+///
+/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution.
+/// @tparam TilePartitioner_ The type of class providing the mapping of a workgroup index onto
+///                          the output data tile to be calculated. It determines the
+///                          workgroup-to-data relationship (in other words, which data is
+///                          processed and calculated by which workgroup).
+/// @tparam GemmPipeline_ The type of class which provides the core part of the matrix
+///                       multiplication. This class should implement data loading from global
+///                       memory and block-wise matrix multiplication, i.e. the work done from
+///                       a single workgroup's point of view.
+/// @tparam EpiloguePipeline_ The type of class providing the final part of the matrix
+///                           multiplication implementation. It is responsible for storing the
+///                           results calculated by @ref GemmPipeline_ "GemmPipeline" to the
+///                           output C tensor in global memory.
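+///
+/// @par Example
+///          A minimal sketch of composing and launching this kernel, following the
+///          usage in example/ck_tile/20_grouped_convolution (the alias names and the
+///          template-argument order are illustrative assumptions):
+/// @code
+/// using Kernel = GroupedConvolutionForwardKernel<GroupedConvTraitsType,
+///                                                TilePartitioner,
+///                                                CodegenPipeline,
+///                                                ConvEpilogue>;
+/// auto kargs       = Kernel::MakeKernelArgs(host_args); // GroupedConvHostArgs
+/// const dim3 grids = Kernel::GridSize(host_args);
+/// if(Kernel::IsSupportedArgument(kargs))
+/// {
+///     launch_kernel(stream_cfg,
+///                   make_kernel(Kernel{}, grids, Kernel::BlockSize(), 0, kargs));
+/// }
+/// @endcode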
+template +struct GroupedConvolutionForwardKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using OutDataType = remove_cvref_t; + + using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = false; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_forward", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args) + { + const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), + args.output_spatial_lengths_.end(), + 1, + std::multiplies()); + const index_t GemmN = args.K_; + return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized + MakeKernelArgs(const GroupedConvHostArgs& hostArgs) + { + return GroupedConvFwdKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvolutionSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = 
kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported input layout!"); + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported weight layout!"); + return false; + } + + // check vector access of E + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + const GroupedConvFwdKernelArgsSpecialized& kargs) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view(a_ptr, kargs.a_grid_desc_m_k); + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view(b_ptr, kargs.b_grid_desc_n_k); + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + return make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_desc_m_n); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return 
pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvFwdKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = + __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. 
+ * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvFwdKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = + __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1); + } + + CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // options + const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; + OutDataType* c_ptr = static_cast(kargs.out_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS( + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n); + } + } + } +}; + +} // namespace 
ck_tile
diff --git a/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp new file mode 100644 index 0000000000..4cbc5c506a --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +enum struct ConvolutionSpecialization +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + Filter3x3, +}; + +CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s) +{ + switch(s) + { + case ConvolutionSpecialization::Default: return "Default"; + case ConvolutionSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionSpecialization::Filter3x3: return "Filter3x3"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp new file mode 100644 index 0000000000..4b7cb3c895 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/convolution_parameter.hpp" + +namespace ck_tile { + +/// @brief The Grouped Conv kernel host arguments. +/// +/// @par Overview +/// This structure is passed to Grouped Convolution kernels when creating the +/// kernel arguments object. It contains all information required to +/// build proper kernel arguments and launch the kernel on the GPU.
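+///
+/// @par Example
+///          A minimal host-side sketch (buffer names are illustrative; assumes a
+///          previously filled ck_tile::conv::ConvParam and no extra D tensors):
+/// @code
+/// ck_tile::GroupedConvHostArgs args(
+///     conv_param, in_dev_buf, wei_dev_buf, {}, out_dev_buf, /*k_batch_=*/1);
+/// @endcode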
+struct GroupedConvHostArgs : public conv::ConvParam +{ + CK_TILE_HOST GroupedConvHostArgs() = delete; + CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, + const void* in_ptr_, + const void* wei_ptr_, + const std::vector ds_ptr_, + void* out_ptr_, + index_t k_batch_) + : conv::ConvParam(conv_param), + in_ptr(in_ptr_), + wei_ptr(wei_ptr_), + ds_ptr(ds_ptr_), + out_ptr(out_ptr_), + k_batch(k_batch_) + { + } + + const void* in_ptr; + const void* wei_ptr; + const std::vector ds_ptr; + void* out_ptr; + index_t k_batch; +}; + +template +struct GroupedConvTraits +{ + private: + static constexpr auto generate_implicit_gemm_layout() + { + return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; }, + number{}); + } + + public: + static constexpr index_t NDimSpatial = NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using DsLayout = DsLayout_; + using OutLayout = OutLayout_; + using GroupedConvImplicitGemmTraits = TileGemmTraits; + static constexpr index_t NumDTensor = DsLayout::size(); + using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp new file mode 100644 index 0000000000..c468ae4398 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -0,0 +1,1432 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" + +namespace ck_tile { + +template +struct TransformConvFwdToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. 
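+                // Worked example (illustrative): for N = 12 and divisor = 3 the
+                // loop below stops at least_divisor = 3, which divides 12, so the
+                // returned split size is 12 / 3 = 4 and each piece stays below 2GB.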
+ for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. + return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvFwdToGemm() {} + + template + CK_TILE_HOST + TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) + : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, + N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if 
constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 
31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate the effective (dilated) filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow the split only if the whole left padding falls into the left tensor and the + // whole right padding into the right tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split the output in half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + // Calculate offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; +
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return the left transformer, the right transformer, the right offset into the + // input and the right offset into the output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimensions + // as properties + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + { + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Wo_), +
make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(number<3>{})), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(number<3>{})), + make_tuple(sequence<0, 2, 3>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); 
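+
+                // The transform chain below realizes im2col for the 1D case: pad Wi
+                // with the left/right pads, embed the padded width into (X, Wo) via
+                // (dilation, stride), then merge (N, Wo) into GemmM and (X, C) into
+                // GemmK.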
+ + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + + { + IndexType HiStride_ = Wi_ * G_ * C_; + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + 
make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + 
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + 
make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + + { + IndexType DiStride_ = Hi_ * Wi_ * G_ * C_; + IndexType HiStride_ = Wi_ * G_ * C_; + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3, 4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Do_), + 
make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple( + sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 6>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + 
make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3, 4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 6>{}, sequence<1, 3, 5, 7>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + 
make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{}, + sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template < + typename BLayout, + typename std::enable_if || + std::is_same_v || + std::is_same_v, + bool>::type = false> + CK_TILE_HOST auto MakeBDescriptor_N_K() const + { + IndexType CStrideTensorB_ = 1; + IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; + IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + using FilterSizeNumType = + std::conditional_t, + std::conditional_t, number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(ZYX_ * C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + 
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple( + NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + // Pad 1 up to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only the matrices on the diagonal. Xor returns 0 for equal + // values, so a matrix that is not on the diagonal is stored in the padding. + // To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge to M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType HoStride_ = Wo_ * G_ * K_; + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Pad 1 up to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}, sequence<4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only the matrices on the diagonal. Xor returns 0 for equal + // values, so a matrix that is not on the diagonal is stored in the padding. + // To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2.
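+            // Illustration (NumGroupsToMerge = 4): xor(g1, g2) == 0 only when
+            // g1 == g2, so only diagonal tiles such as (2, 2) map to the single
+            // valid element of the padded unit dimension; an off-diagonal tile
+            // like (1, 3) maps to xor(1, 3) == 2, i.e. into the padding, and is
+            // skipped.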
+ + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge to M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType DoStride_ = Ho_ * Wo_ * G_ * K_; + IndexType HoStride_ = Wo_ * G_ * K_; + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Pad 1 up to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}, sequence<5>{}, sequence<6>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only the matrices on the diagonal. Xor returns 0 for equal + // values, so a matrix that is not on the diagonal is stored in the padding. + // To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2.
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; + +} // namespace ck_tile
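For intuition, the descriptors above encode the following implicit-GEMM problem sizes. A minimal sketch (illustrative variable names; assumes the packed NHWGC / GKYXC / NHWGK layouts handled by the 2D overloads above):

// One group of the forward convolution maps to a GEMM where GemmM merges the
// output pixels (and merged groups), GemmN covers the output channels, and
// GemmK is the reduction over the filter window and input channels.
const ck_tile::index_t GemmM = N * Ho * Wo * NumGroupsToMerge;
const ck_tile::index_t GemmN = K * NumGroupsToMerge;
const ck_tile::index_t GemmK = Y * X * C;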