From ea6bc81dd1dfdfe264a68bd353c88159bce0b5bd Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Thu, 20 Nov 2025 19:40:48 +0100 Subject: [PATCH] ck-builder: group transfer operations per tensor (#3217) Grouping transfer operations per tensor makes it easier to constrain on and operate with the transfer operations. As an example, we can now deduplicate the logic for translating the transfer operations from the ck-builder interface to the old ck interface for the A and B tensors. [ROCm/composable_kernel commit: 245c6011cfe466233427e307b13be2bd1f114f7f] --- .../builder/conv_algorithm_concepts.hpp | 26 +- .../include/ck_tile/builder/conv_factory.hpp | 66 ++--- experimental/builder/test/CMakeLists.txt | 8 +- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 4 +- .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_fp8.cpp | 2 +- ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../test/impl/conv_algorithm_types.hpp | 92 ++++--- .../builder/test/test_conv_description.cpp | 57 +++-- .../test/utils/ckb_conv_test_configs.hpp | 241 +++++++++++------- 18 files changed, 280 insertions(+), 242 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ea67d5ccc2..7168c4d883 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -125,31 +125,31 @@ concept SpecifiesGridwiseWmmaGemm = requires { // Concept to check if a struct specifies convolution input and output block transfer info. template concept SpecifiesBlockTransfer = requires(T t) { - { T::block_transfer.block_transfer_a } -> BlockTransferDescriptor; - { T::block_transfer.block_transfer_b } -> BlockTransferDescriptor; - { T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor; + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; + { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { - { T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor; - { T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor; - { T::block_transfer.epilogue_c } -> EpilogueDescriptor; + { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.c.epilogue } -> EpilogueDescriptor; }; // Concept to check if a struct specifies thread cluster access order info. template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor; - { T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor; + { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; + { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. template concept SpecifiesSourceAccessOrder = requires(T t) { - { T::block_transfer.src_access_order_a } -> AccessOrderDescriptor; - { T::block_transfer.src_access_order_b } -> AccessOrderDescriptor; + { T::transfer.a.src_access_order } -> AccessOrderDescriptor; + { T::transfer.b.src_access_order } -> AccessOrderDescriptor; }; // Concept to check if struct specifies block GEMM. @@ -246,14 +246,14 @@ concept SpecifiesDlThreadCluster = requires { // Concept to check if algorithm specifies DL block transfer template concept SpecifiesDlBlockTransfer = requires { - { T::block_transfer_a } -> DlBlockTransferDescriptor; - { T::block_transfer_b } -> DlBlockTransferDescriptor; + { T::transfer.a.block_transfer } -> DlBlockTransferDescriptor; + { T::transfer.b.block_transfer } -> DlBlockTransferDescriptor; }; // Concept to check if algorithm specifies DL C thread transfer template concept SpecifiesDlEpilogue = requires { - { T::epilogue_c } -> DlEpilogueDescriptor; + { T::transfer.c.epilogue } -> DlEpilogueDescriptor; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 39260c8acd..248a37d51b 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -25,8 +25,7 @@ // `constexpr` Helper Functions: // - SetThreadBlockInfo: Determines thread block dimensions and tile sizes. // - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters. -// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters. -// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters. +// - SetFwdConvBlockTransfer: Configures A/B tensor block transfer parameters. // - SetCBlockTransfer: Configures C tensor block transfer parameters. // - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types. // @@ -381,32 +380,13 @@ struct BlockTransfer bool lds_padding = false; }; -template -constexpr BlockTransfer SetFwdConvABlockTransfer() +template +constexpr BlockTransfer SetFwdConvBlockTransfer() { - constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a; - constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a; - constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a; - constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a; - - BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1}, - .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]}, - .src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]}, - .src_vector_dim = LDS.src_vector_dim, - .src_scalar_per_vector = LDS.src_scalar_per_vector, - .lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector, - .is_direct_load = LDS.is_direct_load, - .lds_padding = LDS.lds_padding}; - return block_transfer; -} - -template -constexpr BlockTransfer SetFwdConvBBlockTransfer() -{ - constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b; - constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b; - constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b; - constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b; + constexpr auto& TCL = TRANSFER.block_transfer; + constexpr auto& TCO = TRANSFER.block_transfer_access_order; + constexpr auto& SAO = TRANSFER.src_access_order; + constexpr auto& LDS = TRANSFER.lds_transfer; BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1}, .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]}, @@ -431,8 +411,8 @@ struct CBlockTransfer template constexpr CBlockTransfer SetCBlockTransfer() { - constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c; - constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c; + constexpr auto& TCL = ALGORITHM.transfer.c.thread_cluster_dims; + constexpr auto& EPC = ALGORITHM.transfer.c.epilogue; CBlockTransfer block_transfer{.m_per_wave_per_shuffle = EPC.m_per_wave_per_shuffle, .n_per_wave_per_shuffle = EPC.n_per_wave_per_shuffle, .thread_cluster_dims = @@ -568,11 +548,11 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); - static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == - ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, + static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == + ALGORITHM.transfer.b.lds_transfer.is_direct_load, "A and B block transfers must both be direct load or not."); - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load; + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -583,9 +563,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); static constexpr auto BLOCK_GEMM = factory_internal::SetBlockGemm(); @@ -681,9 +661,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); @@ -780,9 +760,9 @@ struct ConvFactory static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = factory_internal::SetGridwiseGemmPipelineVersion(); static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); @@ -884,7 +864,7 @@ struct ConvFactory using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a; + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -900,7 +880,7 @@ struct ConvFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b; + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -916,7 +896,7 @@ struct ConvFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c; + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -998,9 +978,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 5044d223de..1bac03e050 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -39,8 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_2d_fp8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -78,7 +78,7 @@ add_ck_builder_test(test_ckb_conv_description function(collect_test_ckb_targets result_var) # Get all targets in current directory get_directory_property(all_targets BUILDSYSTEM_TARGETS) - + set(test_ckb_targets) foreach(target ${all_targets}) # Check if target name starts with "test_ckb" @@ -87,7 +87,7 @@ function(collect_test_ckb_targets result_var) list(APPEND test_ckb_targets ${target}) endif() endforeach() - + set(${result_var} ${test_ckb_targets} PARENT_SCOPE) endfunction() diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index bb0c767bbd..e37d77e202 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index d391c1a74d..8da5d7c0b0 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 7206c768d8..c13efff952 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} .with_thread_block(FwdThreadBlock_128_64x64x64) .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x32x1) + .with_transfer(FwdTransfer_4x32x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index d0bc1d7a6d..b65e3eb338 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); @@ -50,7 +50,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 0a337f3a7b..d06b3e1f90 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -23,8 +23,7 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) - .with_dl_epilogue(DlEpilogueC); + .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; run_test( @@ -48,8 +47,7 @@ TEST(FwdConvInstances, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) - .with_dl_epilogue(DlEpilogueC); + .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; run_test( diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 798c6cfa2d..470a58fb2a 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index a8313ff510..554b8565ae 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp index 39319bb79e..c26338b579 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1_fp8) + .with_transfer(FwdTransfer_4x64x1_fp8) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 6c43678bf1..73537dae0c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; @@ -50,7 +50,7 @@ TEST( .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 2392d1efff..9140a1ab51 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 52e153098e..42fa8d8be1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 5d1656924c..c630f52c23 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index fd756cf06e..e7c499d9b9 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -103,18 +103,25 @@ struct AccessOrder }; static_assert(AccessOrderDescriptor); -struct BlockTransferABC +struct TransferAB { - BlockTransfer block_transfer_a; - BlockTransfer block_transfer_b; - ThreadCluster thread_cluster_dims_c; - LdsTransfer lds_transfer_a; - LdsTransfer lds_transfer_b; - Epilogue epilogue_c; - AccessOrder block_transfer_access_order_a; - AccessOrder block_transfer_access_order_b; - AccessOrder src_access_order_a; - AccessOrder src_access_order_b; + BlockTransfer block_transfer; + LdsTransfer lds_transfer; + AccessOrder block_transfer_access_order; + AccessOrder src_access_order; +}; + +struct TransferC +{ + ThreadCluster thread_cluster_dims; + Epilogue epilogue; +}; + +struct TransferABC +{ + TransferAB a; + TransferAB b; + TransferC c; }; // DL-specific descriptors @@ -172,9 +179,9 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; -struct BlockTransfer_ +struct Transfer_ { - BlockTransferABC block_transfer; + TransferABC transfer; }; struct ConvSpecialization_ @@ -205,15 +212,26 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; -struct DlBlockTransfer_ +struct DlBlockTransferAB { - DlBlockTransfer block_transfer_a; - DlBlockTransfer block_transfer_b; + DlBlockTransfer block_transfer; }; -struct DlEpilogue_ +struct DlBlockTransferC { - DlEpilogue epilogue_c; + DlEpilogue epilogue; +}; + +struct DlTransferABC +{ + DlBlockTransferAB a; + DlBlockTransferAB b; + DlBlockTransferC c; +}; + +struct DlTransfer_ +{ + DlTransferABC transfer; }; // Specialization wrapper for large tensor support @@ -255,12 +273,12 @@ struct ConvAlgorithmTemplate : Components... return result; } - template - constexpr auto with_block_transfer(const BT& bt) const + template + constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.block_transfer = bt; + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; return result; } @@ -313,21 +331,12 @@ struct ConvAlgorithmTemplate : Components... return result; } - template - constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const + template + constexpr auto with_dl_transfer(const T& t) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.block_transfer_a = bta; - result.block_transfer_b = btb; - return result; - } - - constexpr auto with_dl_epilogue(const DlEpilogue& epi) const - { - static_assert(std::is_base_of_v); - auto result = *this; - result.epilogue_c = epi; + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; return result; } }; @@ -335,20 +344,19 @@ struct ConvAlgorithmTemplate : Components... // Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; + DlTransfer_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index c2f7039348..941fac5389 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -64,30 +64,39 @@ struct DefaultAlgorithm .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; - ckb::test::BlockTransferABC block_transfer{ - .block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8}, - .block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {.order = {0, 1, 2}}, - .block_transfer_access_order_b = {.order = {0, 1, 2}}, - .src_access_order_a = {.order = {0, 1, 2}}, - .src_access_order_b = {.order = {0, 1, 2}}}; + ckb::test::TransferABC transfer{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .block_transfer_access_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, + + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .block_transfer_access_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, + }; ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 7f2acce9c8..55d21e0fd2 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -25,109 +25,152 @@ constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}; +constexpr DlTransferABC DlFwdTransfer{.a = + { + .block_transfer = DlBlockTransferAB, + }, + .b = + { + .block_transfer = DlBlockTransferAB, + }, + .c = { + .epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}, + }}; -constexpr BlockTransferABC FwdBlockTransfer_4x64x1{ - .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x64x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; -constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{ - .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x64x1_fp8{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; -constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ - .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x16x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, -constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ - .block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; + }, +}; + +constexpr TransferABC FwdTransfer_4x32x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};