From 75156c492ebef09b86491a660078b26ff3bc211d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 8 Dec 2025 10:32:56 +0100 Subject: [PATCH] [CK_BUILDER] Ck Tile Grouped convolution factory (#3352) * [BUILDER] Ck Tile Grouped convolution factory * Part 2 * Fixes after rebase * Remove leftovers [ROCm/composable_kernel commit: 04612c30ceab818cd6c03a3e833a6c6d1a21dafa] --- .../builder/conv_algorithm_concepts.hpp | 85 +++++++- .../ck_tile/builder/conv_algorithm_limits.hpp | 5 + .../builder/factory/conv_dispatcher.hpp | 29 ++- .../builder/factory/conv_fwd_dl_factory.hpp | 10 +- .../factory/conv_fwd_large_tensor_factory.hpp | 12 +- .../builder/factory/conv_fwd_v3_factory.hpp | 12 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 12 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 12 +- .../builder/factory/conv_tile_factory.hpp | 131 ++++++++++++ .../helpers/{ => ck}/conv_block_transfer.hpp | 0 .../helpers/{ => ck}/conv_elementwise_op.hpp | 0 .../helpers/{ => ck}/conv_tensor_layout.hpp | 0 .../helpers/{ => ck}/conv_tensor_type.hpp | 0 .../helpers/{ => ck}/conv_thread_block.hpp | 0 .../helpers/{ => ck}/conv_tuning_params.hpp | 0 .../ck_tile/conv_tile_block_transfer.hpp | 25 +++ .../ck_tile/conv_tile_elementwise_op.hpp | 62 ++++++ .../ck_tile/conv_tile_kernel_directions.hpp | 88 ++++++++ .../ck_tile/conv_tile_tensor_layout.hpp | 200 ++++++++++++++++++ .../helpers/ck_tile/conv_tile_tensor_type.hpp | 87 ++++++++ .../ck_tile/conv_tile_thread_block.hpp | 32 +++ .../ck_tile/conv_tile_tuning_params.hpp | 158 ++++++++++++++ .../builder/include/ck_tile/builder/types.hpp | 9 + experimental/builder/test/CMakeLists.txt | 31 +-- .../{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp | 0 .../conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp | 0 ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp | 
0 .../{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp | 0 ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp | 0 .../test/conv/{ => ck}/test_conv_traits.cpp | 0 .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 52 +++++ .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 52 +++++ .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 52 +++++ .../test/impl/conv_algorithm_types.hpp | 118 +++++++++++ .../builder/test/unit_conv_elementwise_op.cpp | 2 +- .../builder/test/unit_conv_tensor_layout.cpp | 2 +- .../builder/test/unit_conv_tensor_type.cpp | 2 +- .../builder/test/unit_conv_thread_block.cpp | 2 +- .../builder/test/unit_conv_tuning_params.cpp | 2 +- .../test/utils/ckb_conv_test_utils.hpp | 16 ++ .../test/utils/ckb_conv_tile_test_configs.hpp | 85 ++++++++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 7 +- ...ouped_convolution_backward_data_kernel.hpp | 17 +- ...ped_convolution_backward_weight_kernel.hpp | 37 ++-- .../grouped_convolution_forward_kernel.hpp | 36 ++-- .../utils/grouped_convolution_utils.hpp | 37 ++++ 55 files changed, 1431 insertions(+), 92 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_block_transfer.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_elementwise_op.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_layout.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_type.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_thread_block.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => 
ck}/conv_tuning_params.hpp (100%) create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp (100%) rename 
experimental/builder/test/conv/{ => ck}/test_conv_traits.cpp (100%) create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ecb1ff933e..bf7e89fcaa 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileThreadBlockDescriptor = requires(T t) { + { t.tile_size.m } -> std::convertible_to; + { t.tile_size.n } -> std::convertible_to; + { t.tile_size.k } -> std::convertible_to; +}; + +// Concept for vector transfer sizes (scalars per vector) of tensors A, B and C +// for a GEMM problem for CK Tile. +template +concept TileTransferDescriptor = requires(T t) { + { t.a_scalar_per_vector } -> std::convertible_to; + { t.b_scalar_per_vector } -> std::convertible_to; + { t.c_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). 
+template +concept TileBlockGemmDescriptor = requires(T t) { + { t.warps.m } -> std::convertible_to; + { t.warps.n } -> std::convertible_to; + { t.warps.k } -> std::convertible_to; + { t.warp_tile.m } -> std::convertible_to; + { t.warp_tile.n } -> std::convertible_to; + { t.warp_tile.k } -> std::convertible_to; + { t.double_smem_buffer } -> std::convertible_to; + { t.num_wave_groups } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept TileOptimizationsDescriptor = requires(T t) { + { t.num_groups_to_merge } -> std::convertible_to; + { t.split_image } -> std::convertible_to; + { t.explicit_gemm } -> std::convertible_to; +}; + // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this // concept. template @@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; +// Concept to check if struct specifies thread block info (CK Tile). +template +concept SpecifiesTileThreadBlock = requires { + { T::thread_block } -> TileThreadBlockDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseXdlGemm = requires { @@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution scalar per vector info for A, B and C. +template +concept SpecifiesTileTransfer = requires(T t) { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + { T::transfer.c_scalar_per_vector } -> std::convertible_to; +}; + // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. 
template concept SpecifiesLdsTransfer = requires(T t) { @@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; +// Concept to check if struct specifies block GEMM (CK Tile). template -concept SpecifiesFwdConcSpecialization = requires { +concept SpecifiesTileBlockGemm = requires { + { T::block_gemm.warps.m } -> std::convertible_to; + { T::block_gemm.warps.n } -> std::convertible_to; + { T::block_gemm.warps.k } -> std::convertible_to; + { T::block_gemm.warp_tile.m } -> std::convertible_to; + { T::block_gemm.warp_tile.n } -> std::convertible_to; + { T::block_gemm.warp_tile.k } -> std::convertible_to; + { T::block_gemm.double_smem_buffer } -> std::convertible_to; + { T::block_gemm.num_wave_groups } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept SpecifiesTileOptimizations = requires { + { T::optimizations.num_groups_to_merge } -> std::convertible_to; + { T::optimizations.split_image } -> std::convertible_to; + { T::optimizations.explicit_gemm } -> std::convertible_to; +}; + +template +concept SpecifiesTileConvSpecialization = requires { + { T::specialization } -> std::convertible_to; +}; + +template +concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 093916dac3..10a619024a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires { Value.lds_dst_scalar_per_vector > 0; }; +// Limits for input and output vector transfer (CK Tile). 
+template +concept TileInputOutputVectorTransferLimits = + requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; + // Limits for output vector transfer. template concept OutputVectorTransferLimits = requires { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 51945544b2..9a9c2235e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -59,6 +59,7 @@ #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" +#include "ck_tile/builder/factory/conv_tile_factory.hpp" namespace ck_tile::builder::factory { @@ -81,6 +82,15 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. 
+// CK Tile kernel +template +consteval bool IsTileAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && + SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && + SpecifiesTileOptimizations; +} + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template consteval bool IsXdlV3Algorithm() @@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; } @@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; } @@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; } @@ -120,7 +130,7 @@ template consteval bool IsDlAlgorithm() { return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && 
SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; } @@ -137,10 +147,15 @@ template constexpr auto make_conv_instance() { - if constexpr(ConvDirectionIsForward) - { - using AlgoType = std::remove_const_t; + using AlgoType = std::remove_const_t; + // CK Tile supports common factory for each direction + if constexpr(IsTileAlgorithm()) + { + return typename ConvTileFactory::Instance{}; + } + else if constexpr(ConvDirectionIsForward) + { if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 0c675ac7f1..ca202aabfd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,11 +7,11 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp 
index 98e368ca61..fadf41f48a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 79955a1f44..89787cc1b3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" 
-#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index fcce46aea7..bb84479071 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index df7fb25168..8ec5c633ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp new file mode 100644 index 0000000000..cce95cb3f1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp" + +namespace ck_tile::builder::factory { + +// Factory for CK Tile Grouped Convolution kernels. +template +struct ConvTileFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::TileConvTensorLayouts; + using Types = internal::TileConvTensorTypes; + using Ops = internal::TileElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization(); + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations(); + static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer(); + static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(TileInputOutputVectorTransferLimits); + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + BLOCK_GEMM.double_smem_buffer, + typename GroupedConvTraitsType::template GemmLayouts::AsLayout, + typename GroupedConvTraitsType::template GemmLayouts::BsLayout, + typename GroupedConvTraitsType::template GemmLayouts::CLayout, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + BLOCK_GEMM.num_wave_groups>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + GemmShape, + GemmUniversalTraits, + BLOCK_GEMM.scheduler, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Types::EDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename internal::TilePipelineType< + BLOCK_GEMM.pipeline_version>::template GemmPipeline; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Instance = typename internal::GroupedConvolutionTileKernel::Instance; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp new file mode 100644 index 0000000000..fbeb48b045 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +struct TileScalarPerVector +{ + size_t a = 0; + size_t b = 0; + size_t c = 0; +}; + +template +constexpr TileScalarPerVector SetTileBlockTransfer() +{ + return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector, + .b = ALGORITHM.transfer.b_scalar_per_vector, + .c = ALGORITHM.transfer.c_scalar_per_vector}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp new file mode 100644 index 0000000000..45ff7d265d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOpToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Clamp; +}; + +template +consteval auto GetTileElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCKTile{}; + } + else + { + return ElementwiseOpToCKTile{}; + } +} + +template +struct TileElementwiseOps +{ + static constexpr auto input_op = GetTileElementwiseOp(); + static constexpr auto weight_op = GetTileElementwiseOp(); + static constexpr auto output_op = GetTileElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp new file mode 100644 index 0000000000..189b199ffc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct GroupedConvolutionTileKernel +{ + static_assert(false, "Unknown Direction"); +}; + +template + requires ConvDirectionIsForward +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionForwardKernel; +}; + +template + requires ConvDirectionIsBackwardData +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardDataKernel; +}; + +template + requires ConvDirectionIsBackwardWeight +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel; +}; + +template +consteval ck_tile::GroupedConvDirection SetTileConvDirection() +{ + constexpr auto direction = SIGNATURE.direction; + using ck_tile_direction = ck_tile::GroupedConvDirection; + switch(direction) + { + case ConvDirection::FORWARD: return ck_tile_direction::FORWARD; + case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA; + case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT; + default: throw "Unknown Direction"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp new file mode 100644 index 0000000000..2aaca98586 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -0,0 +1,200 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { +using ALayout = ck_tile::tensor_layout::convolution::NWGC; +template +struct LayoutToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); +}; + +// Bias layouts +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_K; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_C; +}; + +// Input 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWC; +}; + +// Input 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWC; +}; + +// Input 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWC; +}; + +// Weight 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCX; +}; + +// Weight 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCZYX; +}; +template <> +struct 
LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCKTile() +{ + return typename LayoutToCKTile::type{}; +} + +struct EmptyAuxiliaryTileTensorLayout +{ + using type = ck_tile::tuple<>; +}; + +template +consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) +{ + return ck_tile::tuple< + decltype(TensorLayoutToCKTile())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTileTensorLayouts +{ + static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias). 
+template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return AuxiliaryTileTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return EmptyAuxiliaryTileTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct TileConvTensorLayouts +{ + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp new file mode 100644 index 0000000000..493fbb7d9b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK Tile tensor types. +template +struct TileConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. 
Unsupported data type for convolution factory."); +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::half_t; + using AComputeType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using BComputeType = ck_tile::half_t; + using CShuffleDataType = ck_tile::half_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::half_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::bf16_t; + using AComputeType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using BComputeType = ck_tile::bf16_t; + using CShuffleDataType = ck_tile::bf16_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::bf16_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::fp8_t; + using AComputeType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using BComputeType = ck_tile::fp8_t; + using CShuffleDataType = ck_tile::fp8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::fp8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp new file mode 100644 index 
0000000000..65d81a49c4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileConvBlock +{ + TileBlockMNK per_block = {}; +}; + +template +constexpr TileConvBlock SetTileThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return TileConvBlock{ + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp new file mode 100644 index 0000000000..b7df0e4d0e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. 
+struct TileBlockGemmMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileBlockGemmSpec +{ + TileBlockGemmMNK warps = {}; + TileBlockGemmMNK warp_tile = {}; + + bool double_smem_buffer = false; + int num_wave_groups = 1; + + ck_tile::GemmPipeline pipeline_version; + ck_tile::GemmPipelineScheduler scheduler; +}; + +struct TileOptimizations +{ + int num_groups_to_merge = 1; + bool split_image = false; + bool explicit_gemm = false; +}; + +template +consteval ck_tile::GemmPipelineScheduler SetTileScheduler() +{ + constexpr auto scheduler = ALGORITHM.block_gemm.scheduler; + using ck_tile_sched = ck_tile::GemmPipelineScheduler; + switch(scheduler) + { + case PipelineScheduler::DEFAULT: return ck_tile_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave; + case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave; + default: throw "Unknown PipelineScheduler"; + } +} + +template +struct TilePipelineType +{ + static_assert(false, "Unknown PipelineScheduler"); +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; +}; + +template +consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.block_gemm.pipeline_version; + using ck_tile_pipeline = ck_tile::GemmPipeline; + switch(version) + { + case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1; + case PipelineVersion::V2: return ck_tile_pipeline::MEMORY; + case PipelineVersion::V3: return 
ck_tile_pipeline::COMPUTE_V3; + case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; + case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.specialization; + using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization; + switch(specialization) + { + case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default; + case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0; + case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_tile_conv_spec::Filter1x1Stride1Pad0; + case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; + } +} + +template +consteval TileBlockGemmSpec SetTileBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + constexpr bool double_smem_buffer = BG.double_smem_buffer; + constexpr int num_wave_groups = BG.num_wave_groups; + + constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion(); + constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler(); + + return TileBlockGemmSpec{ + .warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k}, + .warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k}, + .double_smem_buffer = double_smem_buffer, + .num_wave_groups = num_wave_groups, + .pipeline_version = pipeline_version, + .scheduler = scheduler}; +} + +template +consteval TileOptimizations SetTileOptimizations() +{ + constexpr auto& OPT = ALGORITHM.optimizations; + + return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, + .split_image = OPT.split_image, + .explicit_gemm = OPT.explicit_gemm}; +} + +} // 
namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 565bb98528..532d8a1882 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -145,6 +145,15 @@ enum struct GemmSpecialization MNKOPadding }; +// Enums for the CK Tile convolution specialization. +enum class TileConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3x3 +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a340a789de..eef1110d27 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description @@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. 
add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp - conv/test_ckb_conv_fwd_1d_fp16.cpp - conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_2d_dl_fp16.cpp - conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp + conv/ck/test_ckb_conv_fwd_1d_fp16.cpp + conv/ck/test_ckb_conv_fwd_1d_bf16.cpp + conv/ck/test_ckb_conv_fwd_1d_i8.cpp + conv/ck/test_ckb_conv_fwd_2d_fp8.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_bf16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp similarity index 
100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp similarity index 100% rename from experimental/builder/test/conv/test_conv_traits.cpp rename to experimental/builder/test/conv/ck/test_conv_traits.cpp diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp new file mode 100644 index 0000000000..ad31fc52bc --- /dev/null +++ 
b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_data", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp new file mode 100644 index 0000000000..47908e0e5b --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_weight", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp new file mode 100644 index 0000000000..083d9d9955 --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_forward", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index d89d83357f..29c7f3cdcc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,73 @@ struct LargeTensorWrapper ConvAlgorithmSpecialization::LARGE_TENSOR; }; +// Specify thread block dimensions for a GEMM (CK Tile). +struct TileThreadBlock +{ + // Size of the submatrix problem in a thread block. 
+ MNK tile_size; +}; +static_assert(ckb::TileThreadBlockDescriptor); + +struct TileTransfer +{ + size_t a_scalar_per_vector; + size_t b_scalar_per_vector; + size_t c_scalar_per_vector; +}; +static_assert(ckb::TileTransferDescriptor); + +struct TileBlockGemm +{ + // Number of warps per each dimension. + MNK warps; + // Number of data processed per each dimension for each XDL/WMMA instruction. + MNK warp_tile; + // Double LDS buffer. + bool double_smem_buffer; + // Waves grouping (Ping-Pong scheduler). + int num_wave_groups; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; +}; +static_assert(ckb::TileBlockGemmDescriptor); + +struct TileOptimizations +{ + // Number of convolution groups processed per one workgroup + int num_groups_to_merge; + // Split image for large tensors + bool split_image; + // Explicit gemm for 1x1, stride=0, pad=0 cases + bool explicit_gemm; +}; +static_assert(ckb::TileOptimizationsDescriptor); + +struct TileConvSpecialization_ +{ + TileConvSpecialization specialization; +}; + +struct TileThreadBlock_ +{ + TileThreadBlock thread_block; +}; + +struct TileTransfer_ +{ + TileTransfer transfer; +}; + +struct TileBlockGemm_ +{ + TileBlockGemm block_gemm; +}; + +struct TileOptimizations_ +{ + TileOptimizations optimizations; +}; + // Factory template @@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components... 
result.transfer = t; return result; } + + template + constexpr auto with_tile_specializations(const S& s) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.specialization = s; + return result; + } + + template + constexpr auto with_tile_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_tile_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_tile_transfer(const T& t) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; + return result; + } + + template + constexpr auto with_tile_optimizations(const O& o) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.optimizations = o; + return result; + } }; // Algorithm types @@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; +using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp index 84a9c533f6..610edd281e 100644 --- a/experimental/builder/test/unit_conv_elementwise_op.cpp +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 7764e94dc6..26df33cc8d 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ 
b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "impl/conv_signature_types.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index c92b24626e..7ffd446966 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp index f829708696..ce5a772cfa 100644 --- a/experimental/builder/test/unit_conv_thread_block.cpp +++ b/experimental/builder/test/unit_conv_thread_block.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 82117c53d8..b35a1ced55 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -3,7 +3,7 @@ #include -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" namespace { diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index 508c621c2e..1acf170455 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -28,4 +28,20 @@ constexpr void run_test(const 
std::vector& kernel_instance_componen } } +// Common CK Tile test implementation +template +constexpr void run_ck_tile_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + std::cout << kernel_string << std::endl; + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp new file mode 100644 index 0000000000..377234dd19 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr TileTransfer FwdTileTransfer_1x1x1{ + .a_scalar_per_vector = 1, + .b_scalar_per_vector = 1, + .c_scalar_per_vector = 1, +}; + +constexpr TileTransfer FwdTileTransfer_4x4x4{ + .a_scalar_per_vector = 4, + .b_scalar_per_vector = 4, + .c_scalar_per_vector = 4, +}; + +constexpr TileTransfer FwdTileTransfer_8x8x8{ + .a_scalar_per_vector = 8, + .b_scalar_per_vector = 8, + .c_scalar_per_vector = 8, +}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr TileThreadBlock 
FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index d4475e8c60..8fae704203 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK)); + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 2c6b1f3d48..e35f4ce70d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem return concat('_', "gemm_problem", concat('x', kBlockSize), concat('x', kPadM, kPadN, kPadK), - Scheduler); + Scheduler, + "NumWaveGroups", + NumWaveGroups, + "DoubleSmemBuffer", + DoubleSmemBuffer + ); // clang-format on } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index e172e732fa..46c60cb6d7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off return concat('_', 
"grouped_convolution_backward_data", gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, "gemm", GemmPipeline::GetName(), "epilogue", - EpiloguePipeline::GetName()); + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 6ef1d84a6e..f43bfdacac 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { - constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } else { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), "merge", NumGroupsToMerge); - } + return concat('_', "grouped_convolution_backward_weight", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + 
GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 72ba17c5a5..a9f3274805 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel { constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), - "merge", - NumGroupsToMerge); - } else { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const 
{ diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 71739c9083..5b00e53af8 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -9,6 +9,13 @@ namespace ck_tile { +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + /// @brief The Grouped Conv kernel host arguments. /// /// @par Overview @@ -113,6 +120,36 @@ struct GroupedConvTraits using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + template + struct GemmLayouts + { + static_assert(false, "Unsupported direction."); + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutFwd; + using BsLayout = BsLayoutFwd; + using CLayout = CLayoutFwd; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdData; + using BsLayout = BsLayoutBwdData; + using CLayout = CLayoutBwdData; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdWeight; + using BsLayout = BsLayoutBwdWeight; + using CLayout = CLayoutBwdWeight; + }; + template using GroupedConvImplicitGemmTraitsFwd = TileGemmTraits;