From e5a32772612a311109eddc1f77d79ec0e248d52a Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 8 Dec 2025 11:12:53 +0000 Subject: [PATCH] Merge commit '04612c30ceab818cd6c03a3e833a6c6d1a21dafa' into develop --- CHANGELOG.md | 1 + CMakeLists.txt | 17 ++ README.md | 2 +- client_example/CMakeLists.txt | 12 ++ .../builder/conv_algorithm_concepts.hpp | 85 +++++++- .../ck_tile/builder/conv_algorithm_limits.hpp | 5 + .../builder/factory/conv_dispatcher.hpp | 29 ++- .../builder/factory/conv_fwd_dl_factory.hpp | 10 +- .../factory/conv_fwd_large_tensor_factory.hpp | 12 +- .../builder/factory/conv_fwd_v3_factory.hpp | 12 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 12 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 12 +- .../builder/factory/conv_tile_factory.hpp | 131 ++++++++++++ .../helpers/{ => ck}/conv_block_transfer.hpp | 0 .../helpers/{ => ck}/conv_elementwise_op.hpp | 0 .../helpers/{ => ck}/conv_tensor_layout.hpp | 0 .../helpers/{ => ck}/conv_tensor_type.hpp | 0 .../helpers/{ => ck}/conv_thread_block.hpp | 0 .../helpers/{ => ck}/conv_tuning_params.hpp | 0 .../ck_tile/conv_tile_block_transfer.hpp | 25 +++ .../ck_tile/conv_tile_elementwise_op.hpp | 62 ++++++ .../ck_tile/conv_tile_kernel_directions.hpp | 88 ++++++++ .../ck_tile/conv_tile_tensor_layout.hpp | 200 ++++++++++++++++++ .../helpers/ck_tile/conv_tile_tensor_type.hpp | 87 ++++++++ .../ck_tile/conv_tile_thread_block.hpp | 32 +++ .../ck_tile/conv_tile_tuning_params.hpp | 158 ++++++++++++++ .../builder/include/ck_tile/builder/types.hpp | 9 + experimental/builder/test/CMakeLists.txt | 31 +-- .../{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp | 0 .../conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp | 0 ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp | 0 ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp | 0 .../test/conv/{ => ck}/test_conv_traits.cpp | 0 .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 52 +++++ .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 52 +++++ .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 52 +++++ .../test/impl/conv_algorithm_types.hpp | 118 +++++++++++ .../builder/test/unit_conv_elementwise_op.cpp | 2 +- .../builder/test/unit_conv_tensor_layout.cpp | 2 +- .../builder/test/unit_conv_tensor_type.cpp | 2 +- .../builder/test/unit_conv_thread_block.cpp | 2 +- .../builder/test/unit_conv_tuning_params.cpp | 2 +- .../test/utils/ckb_conv_test_utils.hpp | 16 ++ .../test/utils/ckb_conv_tile_test_configs.hpp | 85 ++++++++ include/ck/config.h.in | 11 + .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 7 +- ...ouped_convolution_backward_data_kernel.hpp | 17 +- ...ped_convolution_backward_weight_kernel.hpp | 37 ++-- .../grouped_convolution_forward_kernel.hpp | 36 ++-- .../utils/grouped_convolution_utils.hpp | 37 ++++ .../gpu/grouped_convolution_backward_data.hpp | 16 +- ...ped_convolution_backward_data_bilinear.hpp | 20 +- ...rouped_convolution_backward_data_scale.hpp | 20 +- .../grouped_convolution_backward_weight.hpp | 16 +- ...d_convolution_backward_weight_bilinear.hpp | 11 +- ...uped_convolution_backward_weight_scale.hpp | 10 +- .../gpu/grouped_convolution_forward.hpp | 
17 +- ...d_convolution_forward_bias_bnorm_clamp.hpp | 16 +- ...grouped_convolution_forward_bias_clamp.hpp | 18 +- .../grouped_convolution_forward_bilinear.hpp | 10 +- .../gpu/grouped_convolution_forward_clamp.hpp | 17 +- .../gpu/grouped_convolution_forward_scale.hpp | 10 +- .../gpu/CMakeLists.txt | 21 +- .../src/profile_grouped_conv_bwd_data.cpp | 18 -- .../src/profile_grouped_conv_bwd_weight.cpp | 16 -- profiler/src/profile_grouped_conv_fwd.cpp | 20 -- .../profile_grouped_conv_fwd_bias_clamp.cpp | 6 - .../src/profile_grouped_conv_fwd_clamp.cpp | 6 - test/CMakeLists.txt | 6 + 79 files changed, 1608 insertions(+), 232 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_block_transfer.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_elementwise_op.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_layout.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_type.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_thread_block.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tuning_params.hpp (100%) create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_conv_traits.cpp (100%) create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp create mode 100644 
experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index a50303113d..15fdb09f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. +* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". ### Changed diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d0c4d79f9..acae1f5ece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + # definition will be added based on the GPU target in the following section + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -106,6 +110,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") set(CK_ENABLE_FP8 "ON") @@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") set(CK_GFX950_SUPPORT "ON") endif() +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") +else() + message(STATUS "Disabling TF32 instances") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") +endif() + option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) @@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() diff --git a/README.md b/README.md index 01d523c2ab..8a5258bab6 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb Additional cmake flags can be used to significantly speed-up the build: -* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. 
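With the gating added above, TF32 instances are built only when DTYPES contains "tf32" and the GPU target list includes gfx942 or gfx95*. A minimal configure sketch, assuming the user-facing GPU_TARGETS flag and amdclang++ from a default ROCm install (both illustrative; adjust to your environment):

    # enable TF32 instances alongside fp16/fp32 for a gfx942 build (sketch, not part of the patch)
    cmake -D CMAKE_CXX_COMPILER=amdclang++ -D GPU_TARGETS="gfx942" -D DTYPES="fp16;fp32;tf32" ..

For any other target, the new branch prints "Disabling TF32 instances" and forces CK_ENABLE_TF32 to OFF, so the "tf32" entry in DTYPES has no effect.
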
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 2ed338d08a..cab84f5c6c 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -27,6 +27,9 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -41,6 +44,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") if (GPU_TARGETS MATCHES "gfx94") @@ -67,6 +71,14 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() + if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") + else() + message(STATUS "Disabling TF32 instances for this target") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") + endif() else() add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ecb1ff933e..bf7e89fcaa 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileThreadBlockDescriptor = requires(T t) { + { t.tile_size.m } -> std::convertible_to; + { t.tile_size.n } -> std::convertible_to; + { t.tile_size.k } -> std::convertible_to; +}; + +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileTransferDescriptor = requires(T t) { + { t.a_scalar_per_vector } -> std::convertible_to; + { t.b_scalar_per_vector } -> std::convertible_to; + { t.c_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept TileBlockGemmDescriptor = requires(T t) { + { t.warps.m } -> std::convertible_to; + { t.warps.n } -> std::convertible_to; + { t.warps.k } -> std::convertible_to; + { t.warp_tile.m } -> std::convertible_to; + { t.warp_tile.n } -> std::convertible_to; + { t.warp_tile.k } -> std::convertible_to; + { t.double_smem_buffer } -> std::convertible_to; + { t.num_wave_groups } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept TileOptimizationsDescriptor = requires(T t) { + { t.num_groups_to_merge } -> std::convertible_to; + { t.split_image } -> std::convertible_to; + { t.explicit_gemm } -> std::convertible_to; +}; + // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this // concept. template @@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; +// Concept to check if struct specifies thread block info (CK Tile). 
+template +concept SpecifiesTileThreadBlock = requires { + { T::thread_block } -> TileThreadBlockDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseXdlGemm = requires { @@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. +template +concept SpecifiesTileTransfer = requires(T t) { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + { T::transfer.c_scalar_per_vector } -> std::convertible_to; +}; + // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { @@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; +// Concept to check if struct specifies block GEMM (CK Tile). template -concept SpecifiesFwdConcSpecialization = requires { +concept SpecifiesTileBlockGemm = requires { + { T::block_gemm.warps.m } -> std::convertible_to; + { T::block_gemm.warps.n } -> std::convertible_to; + { T::block_gemm.warps.k } -> std::convertible_to; + { T::block_gemm.warp_tile.m } -> std::convertible_to; + { T::block_gemm.warp_tile.n } -> std::convertible_to; + { T::block_gemm.warp_tile.k } -> std::convertible_to; + { T::block_gemm.double_smem_buffer } -> std::convertible_to; + { T::block_gemm.num_wave_groups } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept SpecifiesTileOptimizations = requires { + { T::optimizations.num_groups_to_merge } -> std::convertible_to; + { T::optimizations.split_image } -> std::convertible_to; + { T::optimizations.explicit_gemm } -> std::convertible_to; +}; + +template +concept SpecifiesTileConvSpecialization = requires { + { T::specialization } -> std::convertible_to; +}; + +template +concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 093916dac3..10a619024a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires { Value.lds_dst_scalar_per_vector > 0; }; +// Limits for input and output vector transfer (CK Tile). +template +concept TileInputOutputVectorTransferLimits = + requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; + // Limits for output vector transfer. 
template concept OutputVectorTransferLimits = requires { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 51945544b2..9a9c2235e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -59,6 +59,7 @@ #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" +#include "ck_tile/builder/factory/conv_tile_factory.hpp" namespace ck_tile::builder::factory { @@ -81,6 +82,15 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. +// CK Tile kernel +template +consteval bool IsTileAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && + SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && + SpecifiesTileOptimizations; +} + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template consteval bool IsXdlV3Algorithm() @@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; } @@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; } @@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; } @@ -120,7 +130,7 @@ template consteval bool IsDlAlgorithm() { return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; } @@ -137,10 +147,15 @@ template constexpr auto make_conv_instance() { - if constexpr(ConvDirectionIsForward) - { - using AlgoType = std::remove_const_t; + using AlgoType = std::remove_const_t; + // CK Tile supports common factory for each direction + if constexpr(IsTileAlgorithm()) + { + return typename ConvTileFactory::Instance{}; + } + else if constexpr(ConvDirectionIsForward) + { if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 
0c675ac7f1..ca202aabfd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,11 +7,11 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 98e368ca61..fadf41f48a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 79955a1f44..89787cc1b3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index fcce46aea7..bb84479071 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index df7fb25168..8ec5c633ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp new file mode 100644 index 0000000000..cce95cb3f1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp" + +namespace ck_tile::builder::factory { + +// Factory for CK Tile Grouped Convolution kernels. +template +struct ConvTileFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::TileConvTensorLayouts; + using Types = internal::TileConvTensorTypes; + using Ops = internal::TileElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization(); + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations(); + static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer(); + static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(TileInputOutputVectorTransferLimits); + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + BLOCK_GEMM.double_smem_buffer, + typename GroupedConvTraitsType::template GemmLayouts::AsLayout, + typename GroupedConvTraitsType::template GemmLayouts::BsLayout, + typename GroupedConvTraitsType::template GemmLayouts::CLayout, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + BLOCK_GEMM.num_wave_groups>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + GemmShape, + GemmUniversalTraits, + BLOCK_GEMM.scheduler, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Types::EDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename internal::TilePipelineType< + BLOCK_GEMM.pipeline_version>::template GemmPipeline; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Instance = typename internal::GroupedConvolutionTileKernel::Instance; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp new file mode 100644 index 0000000000..fbeb48b045 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +struct TileScalarPerVector +{ + size_t a = 0; + size_t b = 0; + size_t c = 0; +}; + +template +constexpr TileScalarPerVector SetTileBlockTransfer() +{ + return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector, + .b = ALGORITHM.transfer.b_scalar_per_vector, + .c = ALGORITHM.transfer.c_scalar_per_vector}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp new file mode 100644 index 0000000000..45ff7d265d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOpToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Clamp; +}; + +template +consteval auto GetTileElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCKTile{}; + } + else + { + return ElementwiseOpToCKTile{}; + } +} + +template +struct TileElementwiseOps +{ + static constexpr auto input_op = GetTileElementwiseOp(); + static constexpr auto weight_op = GetTileElementwiseOp(); + static constexpr auto output_op = GetTileElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp new file mode 100644 index 0000000000..189b199ffc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct GroupedConvolutionTileKernel +{ + static_assert(false, "Unknown Direction"); +}; + +template + requires ConvDirectionIsForward +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionForwardKernel; +}; + +template + requires ConvDirectionIsBackwardData +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardDataKernel; +}; + +template + requires ConvDirectionIsBackwardWeight +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel; +}; + +template +consteval ck_tile::GroupedConvDirection SetTileConvDirection() +{ + constexpr auto direction = SIGNATURE.direction; + using ck_tile_direction = ck_tile::GroupedConvDirection; + switch(direction) + { + case ConvDirection::FORWARD: return ck_tile_direction::FORWARD; + case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA; + case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT; + default: throw "Unknown Direction"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp new file mode 100644 index 0000000000..2aaca98586 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -0,0 +1,200 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { +using ALayout = ck_tile::tensor_layout::convolution::NWGC; +template +struct LayoutToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); +}; + +// Bias layouts +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_K; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_C; +}; + +// Input 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWC; +}; + +// Input 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWC; +}; + +// Input 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWC; +}; + +// Weight 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCX; +}; + +// Weight 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCKTile +{ + using type = 
ck_tile::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCZYX; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCKTile() +{ + return typename LayoutToCKTile::type{}; +} + +struct EmptyAuxiliaryTileTensorLayout +{ + using type = ck_tile::tuple<>; +}; + +template +consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) +{ + return ck_tile::tuple< + decltype(TensorLayoutToCKTile())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTileTensorLayouts +{ + static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). +template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return AuxiliaryTileTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return EmptyAuxiliaryTileTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct TileConvTensorLayouts +{ + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp new file mode 100644 index 0000000000..493fbb7d9b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK Tile tensor types. +template +struct TileConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. 
Unsupported data type for convolution factory."); +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::half_t; + using AComputeType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using BComputeType = ck_tile::half_t; + using CShuffleDataType = ck_tile::half_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::half_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::bf16_t; + using AComputeType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using BComputeType = ck_tile::bf16_t; + using CShuffleDataType = ck_tile::bf16_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::bf16_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::fp8_t; + using AComputeType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using BComputeType = ck_tile::fp8_t; + using CShuffleDataType = ck_tile::fp8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::fp8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp new file mode 100644 index 0000000000..65d81a49c4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileConvBlock +{ + TileBlockMNK per_block = {}; +}; + +template +constexpr TileConvBlock SetTileThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return TileConvBlock{ + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp new file mode 100644 index 0000000000..b7df0e4d0e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockGemmMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileBlockGemmSpec +{ + TileBlockGemmMNK warps = {}; + TileBlockGemmMNK warp_tile = {}; + + bool double_smem_buffer = false; + int num_wave_groups = 1; + + ck_tile::GemmPipeline pipeline_version; + ck_tile::GemmPipelineScheduler scheduler; +}; + +struct TileOptimizations +{ + int num_groups_to_merge = 1; + bool split_image = false; + bool explicit_gemm = false; +}; + +template +consteval ck_tile::GemmPipelineScheduler SetTileScheduler() +{ + constexpr auto scheduler = ALGORITHM.block_gemm.scheduler; + using ck_tile_sched = ck_tile::GemmPipelineScheduler; + switch(scheduler) + { + case PipelineScheduler::DEFAULT: return ck_tile_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave; + case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave; + default: throw "Unknown PipelineScheduler"; + } +} + +template +struct TilePipelineType +{ + static_assert(false, "Unknown PipelineScheduler"); +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; +}; + +template +consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.block_gemm.pipeline_version; + using ck_tile_pipeline = ck_tile::GemmPipeline; + switch(version) + { + case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1; + case PipelineVersion::V2: return ck_tile_pipeline::MEMORY; + case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3; + case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; + case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.specialization; + using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization; + switch(specialization) + { + case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default; + case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0; + case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_tile_conv_spec::Filter1x1Stride1Pad0; + case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; + } +} + +template +consteval TileBlockGemmSpec SetTileBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + constexpr bool double_smem_buffer = BG.double_smem_buffer; + constexpr int num_wave_groups = BG.num_wave_groups; + + constexpr ck_tile::GemmPipeline pipeline_version = 
SetTileBlockGemmPipelineVersion(); + constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler(); + + return TileBlockGemmSpec{ + .warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k}, + .warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k}, + .double_smem_buffer = double_smem_buffer, + .num_wave_groups = num_wave_groups, + .pipeline_version = pipeline_version, + .scheduler = scheduler}; +} + +template +consteval TileOptimizations SetTileOptimizations() +{ + constexpr auto& OPT = ALGORITHM.optimizations; + + return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, + .split_image = OPT.split_image, + .explicit_gemm = OPT.explicit_gemm}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 565bb98528..532d8a1882 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -145,6 +145,15 @@ enum struct GemmSpecialization MNKOPadding }; +// Enums for the CK Tile convolution specialization. +enum class TileConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3x3 +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a340a789de..eef1110d27 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description @@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. 
add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp - conv/test_ckb_conv_fwd_1d_fp16.cpp - conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_2d_dl_fp16.cpp - conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp + conv/ck/test_ckb_conv_fwd_1d_fp16.cpp + conv/ck/test_ckb_conv_fwd_1d_bf16.cpp + conv/ck/test_ckb_conv_fwd_1d_i8.cpp + conv/ck/test_ckb_conv_fwd_2d_fp8.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_bf16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp diff --git 
a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp similarity index 100% rename from experimental/builder/test/conv/test_conv_traits.cpp rename to experimental/builder/test/conv/ck/test_conv_traits.cpp diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp new file mode 100644 index 0000000000..ad31fc52bc --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include "utils/ckb_conv_tile_test_configs.hpp"
+#include "utils/ckb_conv_test_utils.hpp"
+
+namespace {
+
+using namespace ck_tile::builder::test_utils;
+
+TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
+{
+    constexpr ConvSignature FwdConvSignature{.spatial_dim            = 2,
+                                             .direction              = ConvDirection::BACKWARD_DATA,
+                                             .data_type              = DataType::FP16,
+                                             .accumulation_data_type = DataType::FP32,
+                                             .input  = {.config = {.layout = TensorLayout::NHWGC}},
+                                             .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                                             .output = {.config = {.layout = TensorLayout::NHWGK}}};
+
+    constexpr auto FwdConvAlgorithm =
+        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
+            .with_tile_specializations(TileConvSpecialization::DEFAULT)
+            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
+            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
+            .with_tile_transfer(FwdTileTransfer_4x4x4)
+            .with_tile_optimizations(TileOptimizations{
+                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
+
+    using Builder = ConvBuilder;
+    run_ck_tile_test({
+        "grouped_convolution_backward_data",
+        "fp16",
+        "NHWGC_GKYXC_NHWGK",
+        "64x64x64",
+        "2x2",
+        "16x16x16",
+        // "4x4x4", // TODO: Enable this check
+        "Default",
+        "Intrawave",
+        "CShuffleEpilogue",
+        "set",
+        "pipeline_AgBgCrCompV3",
+        "DoubleSmemBuffer_0",
+        "NumWaveGroups_1",
+        "MergedGroups_1",
+        "SplitImage_0",
+        "ExplicitGemm_0",
+    });
+}
+
+} // namespace
diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
new file mode 100644
index 0000000000..47908e0e5b
--- /dev/null
+++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "utils/ckb_conv_tile_test_configs.hpp"
+#include "utils/ckb_conv_test_utils.hpp"
+
+namespace {
+
+using namespace ck_tile::builder::test_utils;
+
+TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
+{
+    constexpr ConvSignature FwdConvSignature{.spatial_dim            = 2,
+                                             .direction              = ConvDirection::BACKWARD_WEIGHT,
+                                             .data_type              = DataType::FP16,
+                                             .accumulation_data_type = DataType::FP32,
+                                             .input  = {.config = {.layout = TensorLayout::NHWGC}},
+                                             .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                                             .output = {.config = {.layout = TensorLayout::NHWGK}}};
+
+    constexpr auto FwdConvAlgorithm =
+        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
+            .with_tile_specializations(TileConvSpecialization::DEFAULT)
+            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
+            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
+            .with_tile_transfer(FwdTileTransfer_4x4x4)
+            .with_tile_optimizations(TileOptimizations{
+                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
+
+    using Builder = ConvBuilder;
+    run_ck_tile_test({
+        "grouped_convolution_backward_weight",
+        "fp16",
+        "NHWGC_GKYXC_NHWGK",
+        "64x64x64",
+        "2x2",
+        "16x16x16",
+        // "4x4x4", // TODO: Enable this check
+        "Default",
+        "Intrawave",
+        "CShuffleEpilogue",
+        "set",
+        "pipeline_AgBgCrCompV3",
+        "DoubleSmemBuffer_0",
+        "NumWaveGroups_1",
+        "MergedGroups_1",
+        "SplitImage_0",
+        "ExplicitGemm_0",
+    });
+}
+
+} // namespace
diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
new file mode 100644
index 0000000000..083d9d9955
--- /dev/null
+++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "utils/ckb_conv_tile_test_configs.hpp"
+#include "utils/ckb_conv_test_utils.hpp"
+
+namespace {
+
+using namespace ck_tile::builder::test_utils;
+
+TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
+{
+    constexpr ConvSignature FwdConvSignature{.spatial_dim            = 2,
+                                             .direction              = ConvDirection::FORWARD,
+                                             .data_type              = DataType::FP16,
+                                             .accumulation_data_type = DataType::FP32,
+                                             .input  = {.config = {.layout = TensorLayout::NHWGC}},
+                                             .weight = {.config = {.layout = TensorLayout::GKYXC}},
+                                             .output = {.config = {.layout = TensorLayout::NHWGK}}};
+
+    constexpr auto FwdConvAlgorithm =
+        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
+            .with_tile_specializations(TileConvSpecialization::DEFAULT)
+            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
+            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
+            .with_tile_transfer(FwdTileTransfer_4x4x4)
+            .with_tile_optimizations(TileOptimizations{
+                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
+
+    using Builder = ConvBuilder;
+    run_ck_tile_test({
+        "grouped_convolution_forward",
+        "fp16",
+        "NHWGC_GKYXC_NHWGK",
+        "64x64x64",
+        "2x2",
+        "16x16x16",
+        // "4x4x4", // TODO: Enable this check
+        "Default",
+        "Intrawave",
+        "CShuffleEpilogue",
+        "set",
+        "pipeline_AgBgCrCompV3",
+        "DoubleSmemBuffer_0",
+        "NumWaveGroups_1",
+        "MergedGroups_1",
+        "SplitImage_0",
+        "ExplicitGemm_0",
+    });
+}
+
+} // namespace
diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp
index d89d83357f..29c7f3cdcc 100644
--- a/experimental/builder/test/impl/conv_algorithm_types.hpp
+++ b/experimental/builder/test/impl/conv_algorithm_types.hpp
@@ -243,6 +243,73 @@ struct LargeTensorWrapper
         ConvAlgorithmSpecialization::LARGE_TENSOR;
 };
 
+// Specify thread block dimensions for a GEMM (CK Tile).
+struct TileThreadBlock
+{
+    // Size of the submatrix problem in a thread block.
+    MNK tile_size;
+};
+static_assert(ckb::TileThreadBlockDescriptor);
+
+struct TileTransfer
+{
+    size_t a_scalar_per_vector;
+    size_t b_scalar_per_vector;
+    size_t c_scalar_per_vector;
+};
+static_assert(ckb::TileTransferDescriptor);
+
+struct TileBlockGemm
+{
+    // Number of warps per each dimension.
+    MNK warps;
+    // Number of data processed per each dimension for each XDL/WMMA instruction.
+    MNK warp_tile;
+    // Double LDS buffer.
+    bool double_smem_buffer;
+    // Waves grouping (Ping-Pong scheduler).
+    int num_wave_groups;
+    PipelineVersion pipeline_version;
+    PipelineScheduler scheduler;
+};
+static_assert(ckb::TileBlockGemmDescriptor);
+
+struct TileOptimizations
+{
+    // Number of convolution groups processed per one workgroup
+    int num_groups_to_merge;
+    // Split image for large tensors
+    bool split_image;
+    // Explicit gemm for 1x1, stride=0, pad=0 cases
+    bool explicit_gemm;
+};
+static_assert(ckb::TileOptimizationsDescriptor);
+
+struct TileConvSpecialization_
+{
+    TileConvSpecialization specialization;
+};
+
+struct TileThreadBlock_
+{
+    TileThreadBlock thread_block;
+};
+
+struct TileTransfer_
+{
+    TileTransfer transfer;
+};
+
+struct TileBlockGemm_
+{
+    TileBlockGemm block_gemm;
+};
+
+struct TileOptimizations_
+{
+    TileOptimizations optimizations;
+};
+
 // Factory
 template
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
         result.transfer = t;
         return result;
     }
+
+    template
+    constexpr auto with_tile_specializations(const S& s) const
+    {
+        static_assert(std::is_base_of_v);
+        auto result           = *this;
+        result.specialization = s;
+        return result;
+    }
+
+    template
+    constexpr auto with_tile_thread_block(const TB& tb) const
+    {
+        static_assert(std::is_base_of_v);
+        auto result         = *this;
+        result.thread_block = tb;
+        return result;
+    }
+
+    template
+    constexpr auto with_tile_block_gemm(const BG& bg) const
+    {
+        static_assert(std::is_base_of_v);
+        auto result       = *this;
+        result.block_gemm = bg;
+        return result;
+    }
+
+    template
+    constexpr auto with_tile_transfer(const T& t) const
+    {
+        static_assert(std::is_base_of_v);
+        auto result     = *this;
+        result.transfer = t;
+        return result;
+    }
+
+    template
+    constexpr auto with_tile_optimizations(const O& o) const
+    {
+        static_assert(std::is_base_of_v);
+        auto result          = *this;
+        result.optimizations = o;
+        return result;
+    }
 };
 
 // Algorithm types
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
 using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
     LargeTensorWrapper;
 
+using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate;
+
 } // namespace ck_tile::builder::test
diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp
index 84a9c533f6..610edd281e 100644
--- a/experimental/builder/test/unit_conv_elementwise_op.cpp
+++ b/experimental/builder/test/unit_conv_elementwise_op.cpp
@@ -4,7 +4,7 @@
 #include
 #include
-#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
 
 namespace {
diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp
index 7764e94dc6..26df33cc8d 100644
--- a/experimental/builder/test/unit_conv_tensor_layout.cpp
+++ b/experimental/builder/test/unit_conv_tensor_layout.cpp
@@ -4,7 +4,7 @@
 #include
 #include
-#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
 #include "impl/conv_signature_types.hpp"
 
 namespace {
diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp
index c92b24626e..7ffd446966 100644
--- a/experimental/builder/test/unit_conv_tensor_type.cpp
+++ b/experimental/builder/test/unit_conv_tensor_type.cpp
@@ -4,7 +4,7 @@
 #include
 #include
-#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
 
 namespace {
diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp
index f829708696..ce5a772cfa 100644
--- a/experimental/builder/test/unit_conv_thread_block.cpp
+++ b/experimental/builder/test/unit_conv_thread_block.cpp
@@ -2,7 +2,7 @@
 // SPDX-License-Identifier: MIT
 
 #include
-#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
 
 namespace {
diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp
index 82117c53d8..b35a1ced55 100644
--- a/experimental/builder/test/unit_conv_tuning_params.cpp
+++ b/experimental/builder/test/unit_conv_tuning_params.cpp
@@ -3,7 +3,7 @@
 
 #include
 
-#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" namespace { diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index 508c621c2e..1acf170455 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -28,4 +28,20 @@ constexpr void run_test(const std::vector& kernel_instance_componen } } +// Common CK Tile test implementation +template +constexpr void run_ck_tile_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + std::cout << kernel_string << std::endl; + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp new file mode 100644 index 0000000000..377234dd19 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr TileTransfer FwdTileTransfer_1x1x1{ + .a_scalar_per_vector = 1, + .b_scalar_per_vector = 1, + .c_scalar_per_vector = 1, +}; + +constexpr TileTransfer FwdTileTransfer_4x4x4{ + .a_scalar_per_vector = 4, + .b_scalar_per_vector = 4, + .c_scalar_per_vector = 4, +}; + +constexpr TileTransfer FwdTileTransfer_8x8x8{ + .a_scalar_per_vector = 8, + .b_scalar_per_vector = 8, + .c_scalar_per_vector = 8, +}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V3, + .scheduler = 
PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 306a6c2ff1..113bf99243 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,6 +55,11 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#define CK_ENABLE_TF32 "ON" +#endif +#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -85,6 +90,12 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ +#endif +#endif + #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index d4475e8c60..8fae704203 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK)); + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 2c6b1f3d48..e35f4ce70d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem return concat('_', "gemm_problem", concat('x', kBlockSize), concat('x', kPadM, kPadN, kPadK), - Scheduler); + Scheduler, + "NumWaveGroups", + NumWaveGroups, + "DoubleSmemBuffer", + DoubleSmemBuffer + ); // clang-format on } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index e172e732fa..46c60cb6d7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off return concat('_', "grouped_convolution_backward_data", 
                     gemm_prec_str(),
+                    InLayout::name,
+                    WeiLayout::name,
+                    OutLayout::name,
                     "gemm",
                     GemmPipeline::GetName(),
                     "epilogue",
-                    EpiloguePipeline::GetName());
+                    EpiloguePipeline::GetName(),
+                    getConvSpecializationString(ConvSpecialization),
+                    "MergedGroups",
+                    NumGroupsToMerge,
+                    "SplitImage",
+                    EnableSplitImage,
+                    "ExplicitGemm",
+                    GroupedConvTraitsType_::ExplicitGemm
+                    );
         // clang-format on
     }
 
+    [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
+
 #ifdef CK_EXPERIMENTAL_BUILDER
     CK_TILE_HOST std::string GetInstanceString() const
     {
diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
index 6ef1d84a6e..f43bfdacac 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
@@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel
     [[nodiscard]] CK_TILE_HOST static const std::string GetName()
     {
-        constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
+        static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
+        constexpr auto NumGroupsToMerge        = GroupedConvTraitsType_::NumGroupsToMerge;
         // clang-format off
-        if (NumGroupsToMerge > 1) {
-            return concat('_', "grouped_convolution_backward_weight",
-                        gemm_prec_str(),
-                        "gemm",
-                        GemmPipeline::GetName(),
-                        "epilogue",
-                        EpiloguePipeline::GetName());
-        } else {
-            return concat('_', "grouped_convolution_backward_weight",
-                        gemm_prec_str(),
-                        "gemm",
-                        GemmPipeline::GetName(),
-                        "epilogue",
-                        EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
-        }
+        return concat('_', "grouped_convolution_backward_weight",
+                    gemm_prec_str(),
+                    InLayout::name,
+                    WeiLayout::name,
+                    OutLayout::name,
+                    "gemm",
+                    GemmPipeline::GetName(),
+                    "epilogue",
+                    EpiloguePipeline::GetName(),
+                    getConvSpecializationString(ConvSpecialization),
+                    "MergedGroups",
+                    NumGroupsToMerge,
+                    "SplitImage",
+                    EnableSplitImage,
+                    "ExplicitGemm",
+                    GroupedConvTraitsType_::ExplicitGemm
+                    );
         // clang-format on
     }
 
+    [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
+
 #ifdef CK_EXPERIMENTAL_BUILDER
     CK_TILE_HOST std::string GetInstanceString() const
     {
diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
index 72ba17c5a5..a9f3274805 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
@@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel
     {
         constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
         // clang-format off
-        if (NumGroupsToMerge > 1) {
-            return concat('_', "grouped_convolution_forward",
-                        gemm_prec_str(),
-                        "gemm",
-                        GemmPipeline::GetName(),
-                        "epilogue",
-                        EpiloguePipeline::GetName(),
-                        "merge",
-                        NumGroupsToMerge);
-        } else {
-            return concat('_', "grouped_convolution_forward",
-                        gemm_prec_str(),
-                        "gemm",
-                        GemmPipeline::GetName(),
-                        "epilogue",
-                        EpiloguePipeline::GetName());
-        }
+        return concat('_', "grouped_convolution_forward",
+                    gemm_prec_str(),
+                    InLayout::name,
+                    WeiLayout::name,
+                    OutLayout::name,
+                    "gemm",
+                    GemmPipeline::GetName(),
+                    "epilogue",
+                    EpiloguePipeline::GetName(),
+                    getConvSpecializationString(ConvSpecialization),
+                    "MergedGroups",
+                    NumGroupsToMerge,
+                    "SplitImage",
+                    EnableSplitImage,
+                    "ExplicitGemm",
+                    GroupedConvTraitsType_::ExplicitGemm
+                    );
         // clang-format on
     }
 
+    [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
+
 #ifdef CK_EXPERIMENTAL_BUILDER
     CK_TILE_HOST std::string GetInstanceString() const
     {
diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp
index 71739c9083..5b00e53af8 100644
--- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp
+++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp
@@ -9,6 +9,13 @@
 namespace ck_tile {
 
+enum class GroupedConvDirection
+{
+    FORWARD,
+    BACKWARD_DATA,
+    BACKWARD_WEIGHT
+};
+
 /// @brief The Grouped Conv kernel host arguments.
 ///
 /// @par Overview
@@ -113,6 +120,36 @@ struct GroupedConvTraits
     using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
     using CLayoutBwdWeight  = ck_tile::tensor_layout::gemm::RowMajor;
 
+    template
+    struct GemmLayouts
+    {
+        static_assert(false, "Unsupported direction.");
+    };
+
+    template <>
+    struct GemmLayouts
+    {
+        using AsLayout = AsLayoutFwd;
+        using BsLayout = BsLayoutFwd;
+        using CLayout  = CLayoutFwd;
+    };
+
+    template <>
+    struct GemmLayouts
+    {
+        using AsLayout = AsLayoutBwdData;
+        using BsLayout = BsLayoutBwdData;
+        using CLayout  = CLayoutBwdData;
+    };
+
+    template <>
+    struct GemmLayouts
+    {
+        using AsLayout = AsLayoutBwdWeight;
+        using BsLayout = BsLayoutBwdWeight;
+        using CLayout  = CLayoutBwdWeight;
+    };
+
     template
     using GroupedConvImplicitGemmTraitsFwd = TileGemmTraits;
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
index 03e3ae88a3..89009c6d0b 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
             op_ptrs);
     }
 #endif
-#ifdef CK_ENABLE_FP32
     if constexpr(is_same_v && is_same_v &&
                  is_same_v)
     {
         static_assert(is_same_v,
                       "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
                 op_ptrs);
             add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
                 op_ptrs);
         }
-        else
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v)
         {
             add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
                 op_ptrs);
             add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
     if constexpr(is_same_v && is_same_v &&
                  is_same_v && is_same_v &&
@@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory<
             op_ptrs);
     }
 #endif
-#ifdef CK_ENABLE_FP32
     if constexpr(is_same_v && is_same_v &&
                  is_same_v)
     {
         static_assert(is_same_v,
                       "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                 op_ptrs);
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
             add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
                 op_ptrs);
         }
-        else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+        if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                 op_ptrs);
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
             add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
     if constexpr(is_same_v && is_same_v &&
                  is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
index cd65a2285a..84a715b70a 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
                                                             PassThrough,
                                                             PassThrough,
                                                             Bilinear>>>& instances);
+#endif
 
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
     std::vector
-                 && is_same_v &&
-                 is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v)
     {
         static_assert(is_same_v,
                       "ComputeTypeA and ComputeTypeB must be the same");
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                 op_ptrs);
         }
-        else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+        if constexpr(is_same_v)
        {
             add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
-    else if constexpr(is_same_v && is_same_v &&
-                      is_same_v && is_same_v &&
-                      is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v && is_same_v &&
+                 is_same_v)
     {
         add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
             op_ptrs);
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
index 36980e5935..c898dbf781 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
                                                             PassThrough,
                                                             PassThrough,
                                                             Scale>>>& instances);
+#endif
 
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
     std::vector
-                 && is_same_v &&
-                 is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v)
     {
         static_assert(is_same_v, " only support same compute type");
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
                 op_ptrs);
         }
-        else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+        if constexpr(is_same_v)
         {
             add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
                 op_ptrs);
         }
-    }
 #endif
+    }
 #ifdef CK_ENABLE_BF16
-    else if constexpr(is_same_v && is_same_v &&
-                      is_same_v && is_same_v &&
-                      is_same_v)
+    if constexpr(is_same_v && is_same_v &&
+                 is_same_v && is_same_v &&
+                 is_same_v)
     {
         add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
             op_ptrs);
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
index e677f6f848..3fe8fa9c5a 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: ComputeTypeA and ComputeTypeB should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
@@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory
                     op_ptrs);
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v &&
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: ComputeTypeA and ComputeTypeB should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory
                     op_ptrs);
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
index 448a6b5d51..a0e8e46570 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
                                                                    PassThrough,
                                                                    Bilinear,
                                                                    PassThrough>>>& instances);
+#endif
+
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
index acf9c9e150..64bbdf6ec5 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
@@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
                                                                    PassThrough,
                                                                    Scale,
                                                                    PassThrough>>>& instances);
+#endif
 
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: this operator requires the same compute type");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
index ba2f6b921a..5089ea2c1e 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
@@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: AComputeType and BComputeType should be the same!");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
@@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs);
                 add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
                     op_ptrs);
@@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v &&
@@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory
                  && is_same_v && is_same_v)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v && is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v &&
+                         is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
                 add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
                     op_ptrs);
@@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v && is_same_v &&
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
index 46bc0d2320..d4729f4d13 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp
@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "A and B compute types should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
@@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v &&
@@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "A and B compute types should be the same");
+#ifdef CK_ENABLE_FP32
             if constexpr(is_same_v)
             {
@@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_TF32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
@@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                     op_ptrs);
@@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v &&
@@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
@@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp
                                                   PassThrough>>>& instances);
+#endif
 
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
                  && is_same_v && DLayouts::Size() == 1 &&
                  is_same_v, NDHWGK>)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v)
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
index 90852d2945..090c99819f 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp
@@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                     op_ptrs);
@@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
         if constexpr(is_same_v &&
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
             static_assert(is_same_v,
                           "Error: AComputeType and BComputeType should be the same");
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory
             }
-            else if constexpr(is_same_v)
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
@@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory
             }
-        }
 #endif
+        }
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp
                                                   PassThrough>>>& instances);
+#endif
 
+#ifdef CK_ENABLE_TF32
 void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
     std::vector
                  && is_same_v && DLayouts::Size() == 0)
     {
-#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v && is_same_v &&
                      is_same_v)
         {
+#ifdef CK_ENABLE_TF32
             if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
                     op_ptrs);
             }
-            else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v)
             {
                 add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
                     op_ptrs);
             }
-        }
 #endif
+        }
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v && is_same_v &&
                      is_same_v && is_same_v)
diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
index eeaf269394..ef037526ca 100644
--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt
+++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt
@@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME)
             set(type1 "_f16")
         elseif(type MATCHES "fp32")
             set(type1 "_f32")
+        elseif(type MATCHES "tf32")
+            set(type1 "_tf32")
         elseif(type MATCHES "fp8")
             set(type1 "_f8")
         elseif(type MATCHES "bf16")
@@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME)
                 #if filename matches any selected type, exit type loop and do no exclude the file from the list
                 set(test 0)
                 break()
-            elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
-                    source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
+            elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
+                    source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
                     NOT (source_name MATCHES type OR source_name MATCHES type1))
                 #if filename contains a type which doesn't match any selected type, mark it for removal
                 set(test 1)
@@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME)
             list(REMOVE_ITEM ARGN "${source}")
         endif()
         # Only build tf32 instances for gfx942 & gfx950
-        if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_")
-            message(DEBUG "removing tf32 instance ${source} ")
-            list(REMOVE_ITEM ARGN "${source}")
+        if(source_name MATCHES "_tf32_")
+            if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32))
+                message(DEBUG "removing tf32 instance ${source} ")
+                list(REMOVE_ITEM ARGN "${source}")
+            endif()
         endif()
     endforeach()
@@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list})
             message(DEBUG "fp32 instance found!")
             set(add_inst 1)
         endif()
+        if(("${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
+            message(DEBUG "tf32 instance found!")
+            set(add_inst 1)
+        endif()
         if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
             message(DEBUG "fp64 instance found!")
             set(add_inst 1)
@@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list})
             "${cmake_instance}" MATCHES "_f16" OR
             "${cmake_instance}" MATCHES "_fp32" OR
             "${cmake_instance}" MATCHES "_f32" OR
+            "${cmake_instance}" MATCHES "_tf32" OR
            "${cmake_instance}" MATCHES "_fp64" OR
             "${cmake_instance}" MATCHES "_f64" OR
             "${cmake_instance}" MATCHES "_bf16" OR
@@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list})
                 list(APPEND CK_DEVICE_OTHER_INSTANCES $)
             endif()
             message(DEBUG "add_instance_directory ${subdir_path}")
-       endif()
+        endif()
     else()
         message(DEBUG "skip_instance_directory ${subdir_path}")
     endif()
diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp
index 62d6e860f9..cbf763fc13 100644
--- a/profiler/src/profile_grouped_conv_bwd_data.cpp
+++ b/profiler/src/profile_grouped_conv_bwd_data.cpp
@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
     using F32  = float;
     using F16  = ck::half_t;
     using BF16 = ck::bhalf_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using namespace ck::tensor_layout::convolution;
 
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
        {
-#if defined(__gfx942__)
             return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
 }
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
     else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp
index a18aab41a5..c4f154e180 100644
--- a/profiler/src/profile_grouped_conv_bwd_weight.cpp
+++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp
@@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
     using BF16 = ck::bhalf_t;
     using F8   = ck::f8_t;
     using BF8  = ck::bf8_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using namespace ck::tensor_layout::convolution;
 
@@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp
index c94b77dd4f..4319d849c8 100644
--- a/profiler/src/profile_grouped_conv_fwd.cpp
+++ b/profiler/src/profile_grouped_conv_fwd.cpp
@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using INT8 = int8_t;
     using F8   = ck::f8_t;
     using BF8  = ck::bf8_t;
-#if defined(__gfx942__) || defined(__gfx950__)
     using TF32 = ck::tf32_t;
-#endif
 
     //
     using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     // NHWGC_GKYXC_NHWGK
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     // NGCDHW_GKCZYX_NGKDHW
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__) || defined(__gfx950__)
             return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
index 4eb12e6e19..79b9beb8c7 100644
--- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
+++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp
@@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
     using F32  = float;
     using BF16 = ck::bhalf_t;
     using F16  = ck::half_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp
index 7df9fd6167..f497ee8da5 100644
--- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp
+++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp
@@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
     using F32  = float;
     using BF16 = ck::bhalf_t;
     using F16  = ck::half_t;
-#if defined(__gfx942__)
     using TF32 = ck::tf32_t;
-#endif
 
     using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
         }
         else if(data_type == ConvDataType::F32_F32_F32_TF32)
         {
-#if defined(__gfx942__)
             return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
-#endif
         }
     }
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f8498c6c03..c221f11f46 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -65,6 +65,9 @@ function(add_test_executable TEST_NAME)
         if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
             set(test 1)
         endif()
+        if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES)
+            set(test 1)
+        endif()
         if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
             set(test 1)
         endif()
@@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME)
         if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
             set(test 1)
         endif()
+        if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES)
+            set(test 1)
+        endif()
         if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
             set(test 1)
         endif()
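
Note on the validation pattern used by the new ck_tile builder tests: the run_ck_tile_test helper added in ckb_conv_test_utils.hpp checks a generated kernel name purely by substring matching against a list of expected components. The following is a minimal, standalone illustration of that pattern (not part of the patch); the fake_kernel_name() stand-in and its hard-coded string are assumptions used only to make the sketch self-contained, since the real string comes from Builder::Instance::GetTypeString(), whose full output this diff does not show.

    // Sketch only: mirrors the substring-check pattern of run_ck_tile_test.
    #include <string>
    #include <vector>
    #include <gtest/gtest.h>
    #include <gmock/gmock.h>

    // Stand-in for Builder::Instance::GetTypeString(); the real value is built by the
    // kernels' GetName() overloads extended in this patch. The contents are assumed.
    static std::string fake_kernel_name()
    {
        return "grouped_convolution_forward_fp16_NHWGC_GKYXC_NHWGK_gemm_pipeline_AgBgCrCompV3_"
               "epilogue_CShuffleEpilogue_Default_MergedGroups_1_SplitImage_0_ExplicitGemm_0";
    }

    TEST(RunCkTileTestPattern, ChecksComponentsBySubstring)
    {
        const std::string kernel_string = fake_kernel_name();
        EXPECT_GT(kernel_string.size(), 0u);

        // Each expected component must appear somewhere in the generated name.
        const std::vector<std::string> components = {
            "grouped_convolution_forward", "fp16", "NHWGC_GKYXC_NHWGK", "CShuffleEpilogue"};
        for(const auto& component : components)
        {
            EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
        }
    }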