diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ccdfb0f6fb..f9ded8a029 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd diff --git a/Jenkinsfile b/Jenkinsfile index f9d7feb77c..b2fda68b70 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -343,15 +343,8 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){ - echo "running ninja build trace" - setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) - build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") - } - else{ - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") - } + setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) + build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -379,7 +372,12 @@ def cmake_build(Map conf=[:]){ archiveArtifacts "clang_build_analysis.log" // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ - sh "ninja check" + if (!params.RUN_ALL_UNIT_TESTS){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } } if(params.BUILD_INSTANCES_ONLY){ // build deb packages @@ -393,7 +391,12 @@ def cmake_build(Map conf=[:]){ else{ // run unit tests unless building library for all targets if (!params.BUILD_INSTANCES_ONLY){ - sh "make check" + if (!params.RUN_ALL_UNIT_TESTS){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } } } } @@ -793,10 +796,10 @@ def process_results(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" @@ -859,8 +862,8 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: "RUN_PERFORMANCE_TESTS", - defaultValue: true, - description: "Run the performance tests (default: ON)") + defaultValue: false, + description: "Run the performance tests (default: OFF)") booleanParam( name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, @@ -913,6 +916,10 @@ pipeline { name: "RUN_INDUCTOR_TESTS", defaultValue: true, description: "Run inductor codegen tests (default: ON)") + booleanParam( + name: "RUN_ALL_UNIT_TESTS", + defaultValue: false, + description: "Run all unit tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -1025,7 +1032,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_CODEGEN_TESTS.toBoolean() } + expression { params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("gfx90a")} environment{ diff --git a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt index e2e68b325a..4ecfec7ccf 100644 --- a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt +++ b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt @@ -1 +1,6 @@ add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +target_compile_options(tile_example_gemm_multi_d_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt new file mode 100644 index 0000000000..00cb0ab9e5 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp) +set(EXAMPLE_CONV_COMPILE_OPTIONS) +list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp new file mode 100644 index 0000000000..685fdccde2 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + return Run(ck_tile::integral_constant{}); +} + +#include "run_grouped_convolution_example.inc" + +template +int run_grouped_conv_fwd_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC") + { + return run_grouped_conv_fwd_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} + +int run_grouped_conv_fwd_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("weight_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_fwd_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_grouped_conv_fwd_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp new file mode 100644 index 0000000000..cc8d365b18 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -0,0 +1,108 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +ck_tile::index_t fill_spatial_dimensions(std::vector& filter_spatial_lengths, + std::vector& image_spatial_lengths, + std::vector& strides, + std::vector& dilations, + std::vector& lpads, + std::vector& rpads, + ck_tile::ArgParser& arg_parser) +{ + + constexpr ck_tile::index_t non_sp_dims = 3; + const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims; + + if(!(n_dim_sp >= 1 && n_dim_sp <= 3)) + { + throw std::runtime_error("Wrong layout!\n"); + } + + if(n_dim_sp == 3) + { + filter_spatial_lengths.push_back(arg_parser.get_int("z")); + image_spatial_lengths.push_back(arg_parser.get_int("d")); + strides.push_back(arg_parser.get_int("stride_d")); + dilations.push_back(arg_parser.get_int("dilation_d")); + lpads.push_back(arg_parser.get_int("lpad_d")); + rpads.push_back(arg_parser.get_int("rpad_d")); + } + if(n_dim_sp >= 2) + { + filter_spatial_lengths.push_back(arg_parser.get_int("y")); + image_spatial_lengths.push_back(arg_parser.get_int("h")); + strides.push_back(arg_parser.get_int("stride_h")); + dilations.push_back(arg_parser.get_int("dilation_h")); + lpads.push_back(arg_parser.get_int("lpad_h")); + rpads.push_back(arg_parser.get_int("rpad_h")); + } + filter_spatial_lengths.push_back(arg_parser.get_int("x")); + image_spatial_lengths.push_back(arg_parser.get_int("w")); + strides.push_back(arg_parser.get_int("stride_w")); + dilations.push_back(arg_parser.get_int("dilation_w")); + lpads.push_back(arg_parser.get_int("lpad_w")); + rpads.push_back(arg_parser.get_int("rpad_w")); + + return n_dim_sp; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("g", "2", "group dimension") + .insert("n", "32", "n dimension") + .insert("k", "32", "k dimension") + .insert("c", "32", "c dimension") + + .insert("d", "64", "d dimension") + .insert("h", "64", "h dimension") + .insert("w", "64", "w dimension") + + .insert("z", "4", "z dimension") + .insert("y", "4", "y dimension") + .insert("x", "4", "x dimension") + + .insert("stride_d", "1", "d stride") + .insert("stride_h", "1", "h stride") + .insert("stride_w", "1", "w stride") + + .insert("dilation_d", "1", "d dilation") + .insert("dilation_h", "1", "h dilation") + .insert("dilation_w", "1", "w dilation") + + .insert("lpad_d", "0", "left pad for d dimension") + .insert("lpad_h", "0", "left pad for h dimension") + .insert("lpad_w", "0", "left pad for w dimension") + + .insert("rpad_d", "0", "right pad for d dimension") + .insert("rpad_h", "0", "right pad for h dimension") + .insert("rpad_w", "0", "right pad for w dimension") + + .insert("in_layout", "NHWGC", "Input image layout - NHWGC by default") + .insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default") + .insert("out_layout", "NHWGK", "Output image layout - NHWGK by default") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc new file mode 100644 index 0000000000..ed72eb354d --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +auto calculate_rtol_atol(const ck_tile::index_t GemmK, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(GemmK, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = + ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat) +{ + float ave_time = grouped_conv_fwd( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_fwd_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(input); + ck_tile::FillUniformDistribution{-5.f, 5.f}(weight); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(input); + ck_tile::FillMonotonicSeq{}(weight); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(input); + ck_tile::FillUniformDistribution{1.f, 1.f}(weight); + } + else + { + input.SetZero(); + weight.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.ToDevice(input.data()); + weight_dev_buf.ToDevice(weight.data()); + output_dev_buf.SetZero(); + + ck_tile::GroupedConvHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Fwd kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + invoke_grouped_conv_fwd(args, n_warmup, n_repeat); + + output_dev_buf.FromDevice(output.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor output_host_ref(out_g_n_k_wos_desc); + output_host_ref.SetZero(); + + ck_tile::reference_grouped_conv_fwd( + input, + weight, + output_host_ref, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + *std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(output, + output_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 92b859a750..8989060842 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) +add_subdirectory(20_grouped_convolution) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) add_subdirectory(37_transpose) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 0ec1a95511..12f49aa4e3 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2784,10 +2784,13 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif } +#if defined(__gfx950__) template __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) { + static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), + "We need to have the compatible compiler version to build this instruction"); if constexpr(std::is_same_v, ck_tile::half_t>) { typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; @@ -2817,6 +2820,7 @@ __device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) static_assert(false, "not implemented"); } } +#endif } // namespace ck_tile diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 53a344c7b0..306d2cdac3 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2554,6 +2554,44 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif } +#if defined(__gfx950__) +template +__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr) +{ + + static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32), + "We need to have the compatible compiler version to build this instruction"); + if constexpr(std::is_same_v, ck_tile::half_t>) + { + typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t; + __attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr)); + } + else if constexpr(std::is_same_v, ck_tile::bf16_t>) + { + typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t; + __attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr)); + } + else if constexpr(std::is_same_v, ck_tile::fp8_t>) + { + typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t; + __attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>( + reinterpret_cast(in_ptr)); + return bit_cast>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr)); + } + else + { + static_assert(false, "not implemented"); + } +} +#endif + } // namespace ck_tile #endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index cd7b7d0a1f..8d19337b86 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -902,8 +902,9 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i, + [[maybe_unused]] index_t linear_offset, + bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -913,13 +914,16 @@ struct buffer_view, t_per_x, addr_space>( p_data_ + i + linear_offset); +#else + return X{numeric>::zero()}; +#endif } else { diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 44851fec4a..4a9748fcbb 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_moe_sorting.hpp" diff --git a/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp new file mode 100644 index 0000000000..8a12fdb7e0 --- /dev/null +++ b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor& input, + const HostTensor& weight, + HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto n, auto k, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x) + { + auto wi = static_cast(wo * conv_strides[0]) + + static_cast(x * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + if(wi >= 0 && ck_tile::type_convert(wi) < input.get_lengths()[3]) + { + InDataType v_in = input(g, n, c, wi); + WeiDataType v_wei = weight(g, k, c, x); + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto n, auto k, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y) + { + auto hi = static_cast(ho * conv_strides[0]) + + static_cast(y * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x) + { + auto wi = static_cast(wo * conv_strides[1]) + + static_cast(x * conv_dilations[1]) - + static_cast(in_left_pads[1]); + + if(hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[3] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[4]) + { + InDataType v_in = input(g, n, c, hi, wi); + WeiDataType v_wei = weight(g, k, c, y, x); + + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, ho, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3], + output.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c) + { + for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z) + { + auto di = static_cast(d_o * conv_strides[0]) + + static_cast(z * conv_dilations[0]) - + static_cast(in_left_pads[0]); + for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y) + { + auto hi = static_cast(ho * conv_strides[1]) + + static_cast(y * conv_dilations[1]) - + static_cast(in_left_pads[1]); + for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x) + { + auto wi = static_cast(wo * conv_strides[2]) + + static_cast(x * conv_dilations[2]) - + static_cast(in_left_pads[2]); + if(di >= 0 && + ck_tile::type_convert(di) < input.get_lengths()[3] && + hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[4] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[5]) + { + InDataType v_in = input(g, n, c, di, hi, wi); + WeiDataType v_wei = weight(g, k, c, z, y, x); + + v_acc += ck_tile::type_convert(v_in) * + ck_tile::type_convert(v_wei); + } + } + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + output(g, n, k, d_o, ho, wo) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + output.get_lengths()[0], + output.get_lengths()[1], + output.get_lengths()[2], + output.get_lengths()[3], + output.get_lengths()[4], + output.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 68e91520bf..bf58544259 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -27,7 +27,9 @@ template + index_t kNumWaveGroups_ = 1, + bool FixedVectorSize_ = false, + index_t VectorSizeC_ = 1> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -48,6 +50,8 @@ struct CShuffleEpilogueProblem static constexpr index_t KPerXdl = KPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -80,6 +84,8 @@ struct CShuffleEpilogue static constexpr index_t NPerXdl = Problem::NPerXdl; static constexpr index_t KPerXdl = Problem::KPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; static constexpr index_t NumDTensor = Problem::NumDTensor; @@ -98,6 +104,10 @@ struct CShuffleEpilogue */ CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() { + if constexpr(FixedVectorSize) + { + return VectorSizeC; + } constexpr index_t max_vector_size = 16; if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 55220730cd..424565060b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -1,8 +1,7 @@ #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 6bb14af9e6..0f7f6369f0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M1 = Problem::VectorSizeA; constexpr index_t M0 = MPerBlock / M1; constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); @@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy if constexpr(std::is_same_v) { - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N1 = Problem::VectorSizeB; constexpr index_t N0 = NPerBlock / N1; constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index b10ee0320f..dc7d150b46 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -14,7 +14,10 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; @@ -24,6 +27,8 @@ struct GemmPipelineProblemBase using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; @@ -115,7 +120,11 @@ struct GemmPipelineProblemBase } static constexpr index_t VectorSizeA = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeA_; + } + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -126,7 +135,11 @@ struct GemmPipelineProblemBase }(); static constexpr index_t VectorSizeB = []() { - if constexpr(std::is_same_v) + if constexpr(FixedVectorSize) + { + return VectorSizeB_; + } + else if constexpr(std::is_same_v) { return kPadN ? 1 : GetAlignmentB(); } @@ -153,13 +166,19 @@ template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> using GemmPipelineProblem = GemmPipelineProblemBase; + ComputeDataType_, + FixedVectorSize_, + VectorSizeA_, + VectorSizeB_>; template + typename ComputeDataType_ = ADataType_, + bool FixedVectorSize_ = false, + index_t VectorSizeA_ = 1, + index_t VectorSizeB_ = 1> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; @@ -179,6 +201,10 @@ struct UniversalGemmPipelineProblem using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeA = VectorSizeA_; + static constexpr index_t VectorSizeB = VectorSizeB_; + using BlockGemmShape = remove_cvref_t; using ALayout = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 91e845d200..d5f2eedf2d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: MPerBlock X KPerBlock @@ -461,10 +462,11 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: KPerBlock X NPerBlock diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp new file mode 100644 index 0000000000..ae5720776c --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp new file mode 100644 index 0000000000..196c468c07 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -0,0 +1,800 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" + +namespace ck_tile { + +/// @brief The Grouped Convolution kernel device arguments. +template +struct GroupedConvFwdKernelArgs +{ + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + a_grid_desc_m_k = + conv_to_gemm_transformer + .template MakeADescriptor_M_K(); + b_grid_desc_n_k = + conv_to_gemm_transformer + .template MakeBDescriptor_N_K(); + c_grid_desc_m_n = + conv_to_gemm_transformer + .template MakeCDescriptor_M_N(); + + group_stride_a = args.C_; + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + group_stride_c = args.K_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + a_grid_desc_m_k = + conv_to_gemm_transformer + .template MakeADescriptor_M_K(); + b_grid_desc_n_k = + conv_to_gemm_transformer + .template MakeBDescriptor_N_K(); + c_grid_desc_m_n = + conv_to_gemm_transformer + .template MakeCDescriptor_M_N(); + + group_stride_a = args.C_; + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + group_stride_c = args.K_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * + args.output_spatial_lengths_[2]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * + args.filter_spatial_lengths_[2]; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + a_grid_desc_m_k = + conv_to_gemm_transformer + .template MakeADescriptor_M_K(); + b_grid_desc_n_k = + conv_to_gemm_transformer + .template MakeBDescriptor_N_K(); + c_grid_desc_m_n = + conv_to_gemm_transformer + .template MakeCDescriptor_M_N(); + + group_stride_a = args.C_; + group_stride_b = args.K_ * args.C_ * + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + group_stride_c = args.K_; + } + + using AGridDescMK = remove_cvref_t())>; + using BGridDescNK = remove_cvref_t())>; + using CGridDescMN = remove_cvref_t())>; + + static constexpr index_t NonSpatialDims = 3; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; + + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; + + index_t k_batch; + index_t GemmM; + index_t GemmN; + index_t GemmK; + + const void* in_ptr; + const void* wei_ptr; + std::array ds_ptr; + void* out_ptr; + + AGridDescMK a_grid_desc_m_k; + BGridDescNK b_grid_desc_n_k; + CGridDescMN c_grid_desc_m_n; + + long_index_t group_stride_a; + long_index_t group_stride_b; + long_index_t group_stride_c; +}; + +/// @brief The Grouped Convolution Forward kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the grouped convolution forward kernel template. By semantic +/// division of Implicit GEMM algorithm into following parts we achieve flexible, +/// versatile and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution. +/// +/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution. +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into +/// the +/// output data tile to be calculated. It determines the +/// workgroup to data relationship (or in other words - which +/// data would be processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of +/// data loading from global memory and performing block-wise +/// matrix multiplication. You can think of it as a work done by +/// single workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output C tensor in global memory. +template +struct GroupedConvolutionForwardKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using OutDataType = remove_cvref_t; + + using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = false; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_forward", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args) + { + const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), + args.output_spatial_lengths_.end(), + 1, + std::multiplies()); + const index_t GemmN = args.K_; + return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized + MakeKernelArgs(const GroupedConvHostArgs& hostArgs) + { + return GroupedConvFwdKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvolutionSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported input layout!"); + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported weight layout!"); + return false; + } + + // check vector access of E + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + const GroupedConvFwdKernelArgsSpecialized& kargs) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view(a_ptr, kargs.a_grid_desc_m_k); + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view(b_ptr, kargs.b_grid_desc_n_k); + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + return make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_desc_m_n); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvFwdKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = + __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvFwdKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = + __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1); + } + + CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // options + const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; + OutDataType* c_ptr = static_cast(kargs.out_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS( + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp new file mode 100644 index 0000000000..4cbc5c506a --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +enum struct ConvolutionSpecialization +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + Filter3x3, +}; + +CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s) +{ + switch(s) + { + case ConvolutionSpecialization::Default: return "Default"; + case ConvolutionSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionSpecialization::Filter3x3: return "Filter3x3"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp new file mode 100644 index 0000000000..4b7cb3c895 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/convolution_parameter.hpp" + +namespace ck_tile { + +/// @brief The Grouped Conv kernel host arguments. +/// +/// @par Overview +/// This structure is passed to Grouped Convolution Kernels when creating kernel +/// arguments object. It contain all necessary information required to +/// build proper kernel argument and launch kernel on GPU. +struct GroupedConvHostArgs : public conv::ConvParam +{ + CK_TILE_HOST GroupedConvHostArgs() = delete; + CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, + const void* in_ptr_, + const void* wei_ptr_, + const std::vector ds_ptr_, + void* out_ptr_, + index_t k_batch_) + : conv::ConvParam(conv_param), + in_ptr(in_ptr_), + wei_ptr(wei_ptr_), + ds_ptr(ds_ptr_), + out_ptr(out_ptr_), + k_batch(k_batch_) + { + } + + const void* in_ptr; + const void* wei_ptr; + const std::vector ds_ptr; + void* out_ptr; + index_t k_batch; +}; + +template +struct GroupedConvTraits +{ + private: + static constexpr auto generate_implicit_gemm_layout() + { + return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; }, + number{}); + } + + public: + static constexpr index_t NDimSpatial = NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using DsLayout = DsLayout_; + using OutLayout = OutLayout_; + using GroupedConvImplicitGemmTraits = TileGemmTraits; + static constexpr index_t NumDTensor = DsLayout::size(); + using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp new file mode 100644 index 0000000000..c468ae4398 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -0,0 +1,1432 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" + +namespace ck_tile { + +template +struct TransformConvFwdToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. + return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvFwdToGemm() {} + + template + CK_TILE_HOST + TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) + : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, + N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + { + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(number<3>{})), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(number<3>{})), + make_tuple(sequence<0, 2, 3>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + + { + IndexType HiStride_ = Wi_ * G_ * C_; + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeADescriptor_M_K() const + + { + IndexType DiStride_ = Hi_ * Wi_ * G_ * C_; + IndexType HiStride_ = Wi_ * G_ * C_; + IndexType WiStride_ = G_ * C_; + IndexType CStrideTensorA_ = 1; + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; + IndexType GStrideTensorA_ = C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3, 4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple( + sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 6>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(number<3>{}, number<3>{}, number<3>{}))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3, 4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 6>{}, sequence<1, 3, 5, 7>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{}, + sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template < + typename BLayout, + typename std::enable_if || + std::is_same_v || + std::is_same_v, + bool>::type = false> + CK_TILE_HOST auto MakeBDescriptor_N_K() const + { + IndexType CStrideTensorB_ = 1; + IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; + IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_; + + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + using FilterSizeNumType = + std::conditional_t, + std::conditional_t, number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(ZYX_ * C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple( + NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType HoStride_ = Wo_ * G_ * K_; + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}, sequence<4>{}, sequence<5>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + template , + bool>::type = false> + CK_TILE_HOST auto MakeCDescriptor_M_N() const + { + IndexType DoStride_ = Ho_ * Wo_ * G_ * K_; + IndexType HoStride_ = Wo_ * G_ * K_; + IndexType WoStride_ = G_ * K_; + IndexType KStrideTensorC_ = 1; + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; + IndexType GStrideTensorC_ = K_; + + const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}, sequence<5>{}, sequence<6>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + } + + IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; + +} // namespace ck_tile diff --git a/script/dependency-parser/README.md b/script/dependency-parser/README.md new file mode 100644 index 0000000000..ff4a44b9a2 --- /dev/null +++ b/script/dependency-parser/README.md @@ -0,0 +1,173 @@ +# Dependency-based Selective Test Filtering using Static Analysis of Ninja Builds for C++ Projects + +## Overview + +This tool provides advanced dependency-based selective test filtering and build optimization for large C++ monorepos using static parsing of Ninja build files. By analyzing both source and header dependencies, it enables precise identification of which tests and executables are affected by code changes, allowing for efficient CI/CD workflows and faster incremental builds. + +The parser: +- Identifies all executables in the Ninja build. +- Maps object files to their source and header dependencies using `ninja -t deps`. +- Constructs a reverse mapping from each file to all dependent executables. +- Handles multi-executable dependencies and supports parallel processing for scalability. +- Exports results in CSV and JSON formats for integration with other tools. + +## Features + +- **Comprehensive Dependency Tracking**: Captures direct source file dependencies and, critically, all included header files via `ninja -t deps`. +- **Executable to Object Mapping**: Parses the `build.ninja` file to understand how executables are linked from object files. +- **Object to Source/Header Mapping**: Uses `ninja -t deps` for each object file to get a complete list of its dependencies. +- **File to Executable Inversion**: Inverts the dependency graph to map each file to the set of executables that depend on it. +- **Parallel Processing**: Utilizes a `ThreadPoolExecutor` to run `ninja -t deps` commands in parallel, significantly speeding up analysis for projects with many object files. +- **Filtering**: Option to filter out system files and focus on project-specific dependencies. +- **Multiple Output Formats**: + - **CSV**: `enhanced_file_executable_mapping.csv` - A comma-separated values file where each row lists a file and a semicolon-separated list of executables that depend on it. + - **JSON**: `enhanced_dependency_mapping.json` - A JSON file representing a dictionary where keys are file paths and values are lists of dependent executables. +- **Robust Error Handling**: Includes error handling for missing files and failed subprocess commands. + +## Prerequisites + +- **Python 3.7+** +- **Ninja build system**: The `ninja` executable must be in the system's PATH or its path provided as an argument. +- A **Ninja build directory** containing a `build.ninja` file and the compiled object files. The project should have been built at least once. + +## Using CMake with Ninja + +To use this tool effectively, your C++ project should be configured with CMake to generate Ninja build files and dependency information. Follow these steps: + +1. **Configure CMake to use Ninja and generate dependencies:** + ```bash + cmake -G Ninja -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_BUILD_TYPE=Release /path/to/your/source + ``` + - The `-G Ninja` flag tells CMake to generate Ninja build files. + - `-DCMAKE_EXPORT_COMPILE_COMMANDS=ON` is optional but useful for other tooling. + - Ensure your CMakeLists.txt uses `target_include_directories` and proper dependency declarations for accurate results. + +2. **Build your project with Ninja:** + ```bash + ninja + ``` + - This step is required to generate all object files and dependency information (`.d` files) that the parser relies on. + +3. **Run the dependency parser tool:** + ```bash + python main.py parse /path/to/build.ninja --workspace-root /path/to/your/workspace + ``` + +**Note:** Always run Ninja to ensure all dependencies are up to date before invoking the parser. If you change source files or headers, re-run Ninja first. + +## Usage + +All features are available via the unified main.py CLI: + +```bash +# Dependency parsing (now supports --workspace-root) +python main.py parse examples/build-ninja/build.ninja --workspace-root /path/to/your/workspace + +# Selective test filtering +python main.py select enhanced_dependency_mapping.json [--all | --test-prefix] [--output ] + +# Code auditing +python main.py audit enhanced_dependency_mapping.json + +# Build optimization +python main.py optimize enhanced_dependency_mapping.json [ ...] +``` + +**Arguments:** + +1. ``: (Required) The full path to the `build.ninja` file within your Ninja build directory. +2. `[--workspace-root ]`: (Optional, recommended) The root directory of your workspace. +3. `[path_to_ninja_executable]`: (Optional) The path to the `ninja` executable if it's not in your system's PATH. Defaults to `ninja`. + +**Example:** + +```bash +# Assuming your build directory is 'build-ninja' and it contains 'build.ninja' +python src/enhanced_ninja_parser.py build-ninja/build.ninja + +# With custom workspace root +python src/enhanced_ninja_parser.py build-ninja/build.ninja ninja /path/to/your/workspace + +# If ninja is installed in a custom location +python src/enhanced_ninja_parser.py /path/to/project/build/build.ninja /usr/local/bin/ninja +``` + +## How It Works + +1. **Initialization**: + * Takes the path to `build.ninja` and optionally the `ninja` executable. + * Sets up internal data structures to store mappings. + +2. **Build File Parsing (`_parse_build_file`)**: + * Reads the `build.ninja` file. + * Uses regular expressions to identify rules for linking executables (e.g., `build my_exe: link main.o utils.o`) and compiling object files (e.g., `build main.o: cxx ../src/main.cpp`). + * Populates `executable_to_objects` (mapping an executable name to a list of its .o files) and `object_to_source` (mapping an object file to its primary source file). + +3. **Object Dependency Extraction (`_extract_all_object_dependencies`)**: + * Iterates through all unique object files identified in the previous step. + * For each object file, it calls `_get_object_dependencies`. + * This process is parallelized using `ThreadPoolExecutor` for efficiency. Each call to `ninja -t deps` runs in a separate thread. + +4. **Individual Object Dependencies (`_get_object_dependencies`)**: + * For a given object file (e.g., `main.o`), it runs the command: `ninja -t deps main.o` in the build directory. + * This command outputs a list of all files that `main.o` depends on, including its primary source (`main.cpp`) and all headers (`*.h`, `*.hpp`) it includes directly or indirectly. + * The output is parsed, cleaned, and returned as a list of file paths. + +5. **Building Final File-to-Executable Mapping (`_build_file_to_executable_mapping`)**: + * This is the core inversion step. It iterates through each executable and its associated object files. + * For each object file, it looks up the full list of its dependencies (source and headers) obtained in step 3 & 4. + * For every dependent file found, it adds the current executable to that file's entry in the `file_to_executables` dictionary. + * If `filter_project_files` is enabled, it checks each dependency against a list of common system paths (e.g., `/usr/include`, `_deps/`) and excludes them if they match. + +6. **Filtering (`_is_project_file`)**: + * A helper function to determine if a given file path is likely a project file or a system/external library file. This helps in focusing the dependency map on the user's own codebase. + +7. **Output Generation**: + * **`export_to_csv(csv_file)`**: Writes the `file_to_executables` mapping to a CSV file. Each row contains a file path and a semicolon-delimited string of executable names. + * **`export_to_json(json_file)`**: Dumps the `file_to_executables` mapping (where the set of executables is converted to a list) into a JSON file. + * **`print_summary()`**: Prints a summary of the findings, including the number of executables, object files, source files, and header files mapped. + +## Output Files + +Running the script will generate two files in the same directory as the input `build.ninja` file: + +- **`enhanced_file_executable_mapping.csv`**: + ```csv + File,Executables + /path/to/project/src/main.cpp,my_exe_1;my_exe_2 + /path/to/project/include/utils.h,my_exe_1;another_test + ... + ``` + +- **`enhanced_dependency_mapping.json`**: + ```json + { + "/path/to/project/src/main.cpp": ["my_exe_1", "my_exe_2"], + "/path/to/project/include/utils.h": ["my_exe_1", "another_test"], + ... + } + ``` + +## Use Cases + +- **Impact Analysis**: Determine which executables (especially tests) need to be rebuilt or re-run when a specific source or header file changes. +- **Build Optimization**: Understand the dependency structure to potentially optimize build times. +- **Code Auditing**: Get a clear overview of how files are used across different executables. +- **Selective Testing**: Integrate with CI/CD systems to run only the tests affected by a given set of changes. + +## Limitations + +- Relies on the accuracy of Ninja's dependency information (`ninja -t deps`). If the build system doesn't correctly generate `.d` (dependency) files, the header information might be incomplete. +- The definition of "project file" vs. "system file" is based on a simple path-based heuristic and might need adjustment for specific project structures. +- Performance for extremely large projects (tens of thousands of object files) might still be a consideration, though parallelization helps significantly. + +## Troubleshooting + +- **"ninja: command not found"**: Ensure `ninja` is installed and in your PATH, or provide the full path to the executable as the second argument. +- **"build.ninja not found"**: Double-check the path to your `build.ninja` file. +- **Empty or Incomplete Output**: + * Make sure the project has been successfully built at least once. `ninja -t deps` relies on information generated during the build. + * Verify that your CMake (or other meta-build system) is configured to generate dependency files for Ninja. +- **Slow Performance**: For very large projects, the number of `ninja -t deps` calls can be substantial. While parallelized, it can still take time. Consider if all object files truly need to be analyzed or if a subset is sufficient for your needs. + +This tool provides a powerful way to gain deep insights into your Ninja project's dependency structure, enabling more intelligent build and test workflows. diff --git a/script/dependency-parser/main.py b/script/dependency-parser/main.py new file mode 100644 index 0000000000..b8fd67ac49 --- /dev/null +++ b/script/dependency-parser/main.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +""" +Unified CLI for Ninja Dependency Analysis and Selective Testing + +Features: +- Dependency parsing (from build.ninja) +- Selective test filtering (between git refs) +- Code auditing (--audit) +- Build optimization (--optimize-build) +""" + +import argparse +import sys +import os + +def run_dependency_parser(args): + from src.enhanced_ninja_parser import main as ninja_main + sys.argv = ["enhanced_ninja_parser.py"] + args + ninja_main() + +def run_selective_test_filter(args): + from src.selective_test_filter import main as filter_main + sys.argv = ["selective_test_filter.py"] + args + filter_main() + +def main(): + parser = argparse.ArgumentParser(description="Unified Ninja Dependency & Selective Testing Tool") + subparsers = parser.add_subparsers(dest="command", required=True) + + # Dependency parsing + parser_parse = subparsers.add_parser("parse", help="Parse build.ninja and generate dependency mapping") + parser_parse.add_argument("build_ninja", help="Path to build.ninja") + parser_parse.add_argument("--ninja", help="Path to ninja executable", default="ninja") + parser_parse.add_argument("--workspace-root", help="Path to workspace root", default=None) + + # Selective testing + parser_test = subparsers.add_parser("select", help="Selective test filtering between git refs") + parser_test.add_argument("depmap_json", help="Path to dependency mapping JSON") + parser_test.add_argument("ref1", help="Source git ref") + parser_test.add_argument("ref2", help="Target git ref") + parser_test.add_argument("--all", action="store_true", help="Include all executables") + parser_test.add_argument("--test-prefix", action="store_true", help="Only include executables starting with 'test_'") + parser_test.add_argument("--output", help="Output JSON file", default="tests_to_run.json") + + # Code auditing + parser_audit = subparsers.add_parser("audit", help="List all files and their dependent executables") + parser_audit.add_argument("depmap_json", help="Path to dependency mapping JSON") + + # Build optimization + parser_opt = subparsers.add_parser("optimize", help="List affected executables for changed files") + parser_opt.add_argument("depmap_json", help="Path to dependency mapping JSON") + parser_opt.add_argument("changed_files", nargs="+", help="List of changed files") + + args = parser.parse_args() + + if args.command == "parse": + parse_args = [args.build_ninja, args.ninja] + if args.workspace_root: + parse_args.append(args.workspace_root) + run_dependency_parser(parse_args) + elif args.command == "select": + filter_args = [args.depmap_json, args.ref1, args.ref2] + if args.test_prefix: + filter_args.append("--test-prefix") + if args.all: + filter_args.append("--all") + if args.output: + filter_args += ["--output", args.output] + run_selective_test_filter(filter_args) + elif args.command == "audit": + run_selective_test_filter([args.depmap_json, "--audit"]) + elif args.command == "optimize": + run_selective_test_filter([args.depmap_json, "--optimize-build"] + args.changed_files) + else: + parser.print_help() + +if __name__ == "__main__": + main() diff --git a/script/dependency-parser/src/enhanced_ninja_parser.py b/script/dependency-parser/src/enhanced_ninja_parser.py new file mode 100644 index 0000000000..087ab50640 --- /dev/null +++ b/script/dependency-parser/src/enhanced_ninja_parser.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +""" +Enhanced Ninja Dependency Parser + +This script combines ninja build file parsing with ninja -t deps to create a comprehensive +mapping that includes both source files AND header files, and properly handles files +used by multiple executables. +""" + +import re +import os +import sys +import subprocess +from pathlib import Path +from collections import defaultdict +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +class EnhancedNinjaDependencyParser: + def __init__(self, build_file_path, ninja_executable="ninja"): + self.build_file_path = build_file_path + self.build_dir = os.path.dirname(build_file_path) + self.ninja_executable = ninja_executable + + # Core data structures + self.executable_to_objects = {} # exe -> [object_files] + self.object_to_source = {} # object -> primary_source + self.object_to_all_deps = {} # object -> [all_dependencies] + self.file_to_executables = defaultdict(set) # file -> {executables} + + # Thread safety + self.lock = threading.Lock() + + def parse_dependencies(self): + """Main method to parse all dependencies.""" + print(f"Parsing ninja dependencies from: {self.build_file_path}") + + # Step 1: Parse build file for executable -> object mappings + self._parse_build_file() + + # Step 2: Get all object files and their dependencies + print(f"Found {len(self.object_to_source)} object files") + print("Extracting detailed dependencies for all object files...") + self._extract_object_dependencies() + + # Step 3: Build the final file -> executables mapping + self._build_file_to_executable_mapping() + + def _parse_build_file(self): + """Parse the ninja build file to extract executable -> object mappings.""" + print("Parsing ninja build file...") + + with open(self.build_file_path, 'r') as f: + content = f.read() + # Parse executable build rules + exe_pattern = r'^build (bin/[^:]+):\s+\S+\s+([^|]+)' + obj_pattern = r'^build ([^:]+\.(?:cpp|cu|hip)\.o):\s+\S+\s+([^\s|]+)' + + lines = content.split('\n') + + for line in lines: + # Match executable rules + exe_match = re.match(exe_pattern, line) + if exe_match and ('EXECUTABLE' in line or 'test_' in exe_match.group(1) or 'example_' in exe_match.group(1)): + exe = exe_match.group(1) + deps_part = exe_match.group(2).strip() + + object_files = [] + for dep in deps_part.split(): + if dep.endswith('.o') and not dep.startswith('/'): + object_files.append(dep) + + self.executable_to_objects[exe] = object_files + continue + + # Match object compilation rules + obj_match = re.match(obj_pattern, line) + if obj_match: + object_file = obj_match.group(1) + source_file = obj_match.group(2) + self.object_to_source[object_file] = source_file + + print(f"Found {len(self.executable_to_objects)} executables") + print(f"Found {len(self.object_to_source)} object-to-source mappings") + + def _extract_object_dependencies(self): + """Extract detailed dependencies for all object files using ninja -t deps.""" + object_files = list(self.object_to_source.keys()) + # Process object files in parallel for better performance + if not object_files: + print("No object files found - skipping dependency extraction") + return + + max_workers = min(16, len(object_files)) # Limit concurrent processes + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all object files for processing + future_to_obj = { + executor.submit(self._get_object_dependencies, obj): obj + for obj in object_files + } + # Process completed futures + completed = 0 + for future in as_completed(future_to_obj): + obj_file = future_to_obj[future] + try: + dependencies = future.result() + with self.lock: + self.object_to_all_deps[obj_file] = dependencies + completed += 1 + if completed % 100 == 0: + print(f"Processed {completed}/{len(object_files)} object files...") + except Exception as e: + print(f"Error processing {obj_file}: {e}") + + print(f"Completed dependency extraction for {len(self.object_to_all_deps)} object files") + + def _get_object_dependencies(self, object_file): + """Get all dependencies for a single object file using ninja -t deps.""" + try: + # Run ninja -t deps for this object file + cmd = [self.ninja_executable, "-t", "deps", object_file] + result = subprocess.run( + cmd, + cwd=self.build_dir, + capture_output=True, + text=True, + timeout=30 + ) + + if result.returncode != 0: + return [] + + dependencies = [] + lines = result.stdout.strip().split('\n') + + for line in lines[1:]: # Skip first line with metadata + line = line.strip() + if line and not line.startswith('#'): + # Convert absolute paths to relative paths from workspace root + dep_file = line + ws_root = getattr(self, "workspace_root", "..") + ws_prefix = ws_root.rstrip("/") + "/" + if dep_file.startswith(ws_prefix): + dep_file = dep_file[len(ws_prefix):] + dependencies.append(dep_file) + + return dependencies + + except Exception as e: + print(f"Error getting dependencies for {object_file}: {e}") + return [] + + def _build_file_to_executable_mapping(self): + """Build the final mapping from files to executables.""" + print("Building file-to-executable mapping...") + + for exe, object_files in self.executable_to_objects.items(): + for obj_file in object_files: + # Add all dependencies of this object file + if obj_file in self.object_to_all_deps: + for dep_file in self.object_to_all_deps[obj_file]: + # Filter out system files and focus on project files + if self._is_project_file(dep_file): + self.file_to_executables[dep_file].add(exe) + + print(f"Built mapping for {len(self.file_to_executables)} files") + + # Show statistics + multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1} + print(f"Files used by multiple executables: {len(multi_exe_files)}") + + if multi_exe_files: + print("Sample files with multiple dependencies:") + for f, exes in sorted(multi_exe_files.items())[:5]: + print(f" {f}: {len(exes)} executables") + + def _is_project_file(self, file_path): + """Determine if a file is part of the project (not system files).""" + # Include files that are clearly part of the project + if any(file_path.startswith(prefix) for prefix in [ + 'include/', 'library/', 'test/', 'example/', 'src/', 'profiler/', + 'build/include/', 'build/_deps/gtest', 'client_example', 'codegen', 'tile_engine' + ]): + return True + + # Exclude system files + if any(file_path.startswith(prefix) for prefix in [ + '/usr/', '/opt/rocm', '/lib/', '/system/', '/local/' + ]): + return False + + # Include files with common source/header extensions + if file_path.endswith(('.cpp', '.hpp', '.h', '.c', '.cc', '.cxx', '.cu', '.hip', '.inc')): + return True + + return False + + def export_to_csv(self, output_file): + """Export the file-to-executable mapping to CSV with proper comma separation.""" + print(f"Exporting mapping to {output_file}") + + with open(output_file, 'w') as f: + f.write("source_file,executables\n") + for file_path in sorted(self.file_to_executables.keys()): + executables = sorted(self.file_to_executables[file_path]) + # Use semicolon to separate multiple executables within the field + exe_list = ';'.join(executables) + f.write(f'"{file_path}","{exe_list}"\n') + + def export_to_json(self, output_file): + """Export the complete mapping to JSON.""" + print(f"Exporting complete mapping to {output_file}") + + # Build reverse mapping (executable -> files) + exe_to_files = defaultdict(set) + for file_path, exes in self.file_to_executables.items(): + for exe in exes: + exe_to_files[exe].add(file_path) + + mapping_data = { + 'file_to_executables': { + file_path: list(exes) for file_path, exes in self.file_to_executables.items() + }, + 'executable_to_files': { + exe: sorted(files) for exe, files in exe_to_files.items() + }, + 'statistics': { + 'total_files': len(self.file_to_executables), + 'total_executables': len(self.executable_to_objects), + 'total_object_files': len(self.object_to_source), + 'files_with_multiple_executables': len([f for f, exes in self.file_to_executables.items() if len(exes) > 1]) + } + } + + with open(output_file, 'w') as f: + json.dump(mapping_data, f, indent=2) + + def print_summary(self): + """Print a summary of the parsed dependencies.""" + print("\n=== Enhanced Dependency Mapping Summary ===") + print(f"Total executables: {len(self.executable_to_objects)}") + print(f"Total files mapped: {len(self.file_to_executables)}") + print(f"Total object files processed: {len(self.object_to_all_deps)}") + + # Files by type + cpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.cpp')) + hpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.hpp')) + h_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.h')) + + print(f"\nFile types:") + print(f" .cpp files: {cpp_files}") + print(f" .hpp files: {hpp_files}") + print(f" .h files: {h_files}") + + # Multi-executable files + multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1} + print(f"\nFiles used by multiple executables: {len(multi_exe_files)}") + + if multi_exe_files: + print("\nTop files with most dependencies:") + sorted_multi = sorted(multi_exe_files.items(), key=lambda x: len(x[1]), reverse=True) + for file_path, exes in sorted_multi[:10]: + print(f" {file_path}: {len(exes)} executables") + +def main(): + # Accept: build_file, ninja_path, workspace_root + default_workspace_root = ".." + if len(sys.argv) > 3: + build_file = sys.argv[1] + ninja_path = sys.argv[2] + workspace_root = sys.argv[3] + elif len(sys.argv) > 2: + build_file = sys.argv[1] + ninja_path = sys.argv[2] + workspace_root = default_workspace_root + elif len(sys.argv) > 1: + build_file = sys.argv[1] + ninja_path = "ninja" + workspace_root = default_workspace_root + else: + build_file = f"{default_workspace_root}/build/build.ninja" + ninja_path = "ninja" + workspace_root = default_workspace_root + + if not os.path.exists(build_file): + print(f"Error: Build file not found: {build_file}") + sys.exit(1) + + try: + subprocess.run([ninja_path, "--version"], capture_output=True, check=True) + except (subprocess.CalledProcessError, FileNotFoundError): + print(f"Error: ninja executable not found: {ninja_path}") + sys.exit(1) + + parser = EnhancedNinjaDependencyParser(build_file, ninja_path) + parser.workspace_root = workspace_root # Attach for use in _get_object_dependencies + parser.parse_dependencies() + parser.print_summary() + + # Export results + output_dir = os.path.dirname(build_file) + csv_file = os.path.join(output_dir, 'enhanced_file_executable_mapping.csv') + json_file = os.path.join(output_dir, 'enhanced_dependency_mapping.json') + + parser.export_to_csv(csv_file) + parser.export_to_json(json_file) + + print(f"\nResults exported to:") + print(f" CSV: {csv_file}") + print(f" JSON: {json_file}") + +if __name__ == "__main__": + main() diff --git a/script/dependency-parser/src/selective_test_filter.py b/script/dependency-parser/src/selective_test_filter.py new file mode 100644 index 0000000000..f364d60d27 --- /dev/null +++ b/script/dependency-parser/src/selective_test_filter.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +Selective Test Filter Tool + +Given two git refs (branches or commit IDs), this tool: +- Identifies changed files between the refs +- Loads the enhanced dependency mapping JSON (from enhanced_ninja_parser.py) +- Maps changed files to affected test executables (optionally filtering for "test_" prefix) +- Exports the list of tests to run to tests_to_run.json + +Usage: + python selective_test_filter.py [--all | --test-prefix] [--output ] + +Arguments: + Path to enhanced_dependency_mapping.json + Source git ref (branch or commit) + Target git ref (branch or commit) + +Options: + --all Include all executables (default) + --test-prefix Only include executables starting with "test_" + --output Output JSON file (default: tests_to_run.json) +""" + +import sys +import subprocess +import json +import os + +def get_changed_files(ref1, ref2): + """Return a set of files changed between two git refs.""" + try: + result = subprocess.run( + ["git", "diff", "--name-only", ref1, ref2], + capture_output=True, text=True, check=True + ) + files = set(line.strip() for line in result.stdout.splitlines() if line.strip()) + return files + except subprocess.CalledProcessError as e: + print(f"Error running git diff: {e}") + sys.exit(1) + +def load_depmap(depmap_json): + """Load the dependency mapping JSON.""" + with open(depmap_json, "r") as f: + data = json.load(f) + # Support both old and new formats + if "file_to_executables" in data: + return data["file_to_executables"] + return data + +def select_tests(file_to_executables, changed_files, filter_mode): + """Return a set of test executables affected by changed files.""" + affected = set() + for f in changed_files: + if f in file_to_executables: + for exe in file_to_executables[f]: + if filter_mode == "all": + affected.add(exe) + elif filter_mode == "test_prefix" and exe.startswith("test_"): + affected.add(exe) + return sorted(affected) + +def main(): + if "--audit" in sys.argv: + if len(sys.argv) < 2: + print("Usage: python selective_test_filter.py --audit") + sys.exit(1) + depmap_json = sys.argv[1] + if not os.path.exists(depmap_json): + print(f"Dependency map JSON not found: {depmap_json}") + sys.exit(1) + file_to_executables = load_depmap(depmap_json) + for f, exes in file_to_executables.items(): + print(f"{f}: {', '.join(exes)}") + print(f"Total files: {len(file_to_executables)}") + sys.exit(0) + + if "--optimize-build" in sys.argv: + if len(sys.argv) < 3: + print("Usage: python selective_test_filter.py --optimize-build [ ...]") + sys.exit(1) + depmap_json = sys.argv[1] + changed_files = set(sys.argv[sys.argv.index("--optimize-build") + 1 :]) + if not os.path.exists(depmap_json): + print(f"Dependency map JSON not found: {depmap_json}") + sys.exit(1) + file_to_executables = load_depmap(depmap_json) + affected_executables = set() + for f in changed_files: + if f in file_to_executables: + affected_executables.update(file_to_executables[f]) + print("Affected executables:") + for exe in sorted(affected_executables): + print(exe) + print(f"Total affected executables: {len(affected_executables)}") + sys.exit(0) + + if len(sys.argv) < 4: + print("Usage: python selective_test_filter.py [--all | --test-prefix] [--output ]") + sys.exit(1) + + depmap_json = sys.argv[1] + ref1 = sys.argv[2] + ref2 = sys.argv[3] + filter_mode = "all" + output_json = "tests_to_run.json" + + if "--test-prefix" in sys.argv: + filter_mode = "test_prefix" + if "--all" in sys.argv: + filter_mode = "all" + if "--output" in sys.argv: + idx = sys.argv.index("--output") + if idx + 1 < len(sys.argv): + output_json = sys.argv[idx + 1] + + if not os.path.exists(depmap_json): + print(f"Dependency map JSON not found: {depmap_json}") + sys.exit(1) + + changed_files = get_changed_files(ref1, ref2) + if not changed_files: + print("No changed files detected.") + tests = [] + else: + file_to_executables = load_depmap(depmap_json) + tests = select_tests(file_to_executables, changed_files, filter_mode) + + with open(output_json, "w") as f: + json.dump({"tests_to_run": tests, "changed_files": sorted(changed_files)}, f, indent=2) + + print(f"Exported {len(tests)} tests to run to {output_json}") + +if __name__ == "__main__": + main() diff --git a/script/launch_tests.sh b/script/launch_tests.sh new file mode 100755 index 0000000000..829ac82378 --- /dev/null +++ b/script/launch_tests.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Get the directory where the script is located +BUILD_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Go one level up to PACKAGE_HOME +PACKAGE_HOME="$(dirname "$BUILD_DIR")" + +SCRIPT_DIR="$PACKAGE_HOME/script/" + +# Search for build.ninja under PACKAGE_HOME +BUILD_NINJA_FILE="$PACKAGE_HOME/build/build.ninja" + +if [ -z "$BUILD_NINJA_FILE" ]; then + echo "Error: build.ninja not found under $PACKAGE_HOME" + exit 1 +fi + +python3 "$SCRIPT_DIR/dependency-parser/main.py" parse "$BUILD_NINJA_FILE" --workspace-root "$PACKAGE_HOME" + +# Get the directory containing build.ninja +BUILD_DIR=$(dirname "$BUILD_NINJA_FILE") + +# Path to enhanced_dependency_mapping.json in the same directory +JSON_FILE="$BUILD_DIR/enhanced_dependency_mapping.json" + +# Check if the JSON file exists +if [ ! -f "$JSON_FILE" ]; then + echo "Error: $JSON_FILE not found." + exit 1 +fi + +branch=$(git rev-parse --abbrev-ref HEAD) + +# Run the command +python3 "$SCRIPT_DIR/dependency-parser/main.py" select "$JSON_FILE" origin/develop $branch + +# Path to tests_to_run.json in the same directory +TEST_FILE="tests_to_run.json" + +command=$(python3 -c " +import json +import os +with open('$TEST_FILE', 'r') as f: + data = json.load(f) + tests = data.get('tests_to_run', []) + if tests: + # Extract just the filename after the last '/' + clean_tests = [os.path.basename(test) for test in tests] + print('ctest -R \"' + '|'.join(clean_tests) + '\"') + else: + print('# No tests to run') +") + +echo "$command" + +eval "$command" + + diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index cfc5b0cd1a..8f880b8fde 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,9 +1,9 @@ # Currently ck_tile_gemm is only built on gfx94/gfx95 -set(EXAMPLE_GEMM_COMPILE_OPTIONS "") +set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS "") +set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index 1ec77eb87a..a50de7178b 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -1,4 +1,10 @@ # Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) + target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif()