diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ea67d5ccc2..7168c4d883 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -125,31 +125,31 @@ concept SpecifiesGridwiseWmmaGemm = requires { // Concept to check if a struct specifies convolution input and output block transfer info. template concept SpecifiesBlockTransfer = requires(T t) { - { T::block_transfer.block_transfer_a } -> BlockTransferDescriptor; - { T::block_transfer.block_transfer_b } -> BlockTransferDescriptor; - { T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor; + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; + { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { - { T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor; - { T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor; - { T::block_transfer.epilogue_c } -> EpilogueDescriptor; + { T::transfer.a.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.b.lds_transfer } -> LdsTransferDescriptor; + { T::transfer.c.epilogue } -> EpilogueDescriptor; }; // Concept to check if a struct specifies thread cluster access order info. template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor; - { T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor; + { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; + { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. template concept SpecifiesSourceAccessOrder = requires(T t) { - { T::block_transfer.src_access_order_a } -> AccessOrderDescriptor; - { T::block_transfer.src_access_order_b } -> AccessOrderDescriptor; + { T::transfer.a.src_access_order } -> AccessOrderDescriptor; + { T::transfer.b.src_access_order } -> AccessOrderDescriptor; }; // Concept to check if struct specifies block GEMM. @@ -246,14 +246,14 @@ concept SpecifiesDlThreadCluster = requires { // Concept to check if algorithm specifies DL block transfer template concept SpecifiesDlBlockTransfer = requires { - { T::block_transfer_a } -> DlBlockTransferDescriptor; - { T::block_transfer_b } -> DlBlockTransferDescriptor; + { T::transfer.a.block_transfer } -> DlBlockTransferDescriptor; + { T::transfer.b.block_transfer } -> DlBlockTransferDescriptor; }; // Concept to check if algorithm specifies DL C thread transfer template concept SpecifiesDlEpilogue = requires { - { T::epilogue_c } -> DlEpilogueDescriptor; + { T::transfer.c.epilogue } -> DlEpilogueDescriptor; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 39260c8acd..248a37d51b 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -25,8 +25,7 @@ // `constexpr` Helper Functions: // - SetThreadBlockInfo: Determines thread block dimensions and tile sizes. // - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters. -// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters. -// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters. +// - SetFwdConvBlockTransfer: Configures A/B tensor block transfer parameters. // - SetCBlockTransfer: Configures C tensor block transfer parameters. // - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types. // @@ -381,32 +380,13 @@ struct BlockTransfer bool lds_padding = false; }; -template -constexpr BlockTransfer SetFwdConvABlockTransfer() +template +constexpr BlockTransfer SetFwdConvBlockTransfer() { - constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a; - constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a; - constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a; - constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a; - - BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1}, - .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]}, - .src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]}, - .src_vector_dim = LDS.src_vector_dim, - .src_scalar_per_vector = LDS.src_scalar_per_vector, - .lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector, - .is_direct_load = LDS.is_direct_load, - .lds_padding = LDS.lds_padding}; - return block_transfer; -} - -template -constexpr BlockTransfer SetFwdConvBBlockTransfer() -{ - constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b; - constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b; - constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b; - constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b; + constexpr auto& TCL = TRANSFER.block_transfer; + constexpr auto& TCO = TRANSFER.block_transfer_access_order; + constexpr auto& SAO = TRANSFER.src_access_order; + constexpr auto& LDS = TRANSFER.lds_transfer; BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1}, .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]}, @@ -431,8 +411,8 @@ struct CBlockTransfer template constexpr CBlockTransfer SetCBlockTransfer() { - constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c; - constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c; + constexpr auto& TCL = ALGORITHM.transfer.c.thread_cluster_dims; + constexpr auto& EPC = ALGORITHM.transfer.c.epilogue; CBlockTransfer block_transfer{.m_per_wave_per_shuffle = EPC.m_per_wave_per_shuffle, .n_per_wave_per_shuffle = EPC.n_per_wave_per_shuffle, .thread_cluster_dims = @@ -568,11 +548,11 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); - static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == - ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, + static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == + ALGORITHM.transfer.b.lds_transfer.is_direct_load, "A and B block transfers must both be direct load or not."); - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load; + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -583,9 +563,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); static constexpr auto BLOCK_GEMM = factory_internal::SetBlockGemm(); @@ -681,9 +661,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); @@ -780,9 +760,9 @@ struct ConvFactory static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = factory_internal::SetGridwiseGemmPipelineVersion(); static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); @@ -884,7 +864,7 @@ struct ConvFactory using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a; + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -900,7 +880,7 @@ struct ConvFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b; + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -916,7 +896,7 @@ struct ConvFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c; + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -998,9 +978,9 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 5044d223de..1bac03e050 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -39,8 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_2d_fp8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -78,7 +78,7 @@ add_ck_builder_test(test_ckb_conv_description function(collect_test_ckb_targets result_var) # Get all targets in current directory get_directory_property(all_targets BUILDSYSTEM_TARGETS) - + set(test_ckb_targets) foreach(target ${all_targets}) # Check if target name starts with "test_ckb" @@ -87,7 +87,7 @@ function(collect_test_ckb_targets result_var) list(APPEND test_ckb_targets ${target}) endif() endforeach() - + set(${result_var} ${test_ckb_targets} PARENT_SCOPE) endfunction() diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index bb0c767bbd..e37d77e202 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index d391c1a74d..8da5d7c0b0 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 7206c768d8..c13efff952 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} .with_thread_block(FwdThreadBlock_128_64x64x64) .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x32x1) + .with_transfer(FwdTransfer_4x32x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index d0bc1d7a6d..b65e3eb338 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); @@ -50,7 +50,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 0a337f3a7b..d06b3e1f90 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -23,8 +23,7 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) - .with_dl_epilogue(DlEpilogueC); + .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; run_test( @@ -48,8 +47,7 @@ TEST(FwdConvInstances, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) - .with_dl_epilogue(DlEpilogueC); + .with_dl_transfer(DlFwdTransfer); using Builder = ConvBuilder; run_test( diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 798c6cfa2d..470a58fb2a 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index a8313ff510..554b8565ae 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp index 39319bb79e..c26338b579 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1_fp8) + .with_transfer(FwdTransfer_4x64x1_fp8) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 6c43678bf1..73537dae0c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; @@ -50,7 +50,7 @@ TEST( .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} .with_thread_block(FwdThreadBlock_128_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_transfer(FwdTransfer_4x16x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 2392d1efff..9140a1ab51 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 52e153098e..42fa8d8be1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 5d1656924c..c630f52c23 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} .with_thread_block(FwdThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_transfer(FwdTransfer_4x64x1) .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index fd756cf06e..e7c499d9b9 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -103,18 +103,25 @@ struct AccessOrder }; static_assert(AccessOrderDescriptor); -struct BlockTransferABC +struct TransferAB { - BlockTransfer block_transfer_a; - BlockTransfer block_transfer_b; - ThreadCluster thread_cluster_dims_c; - LdsTransfer lds_transfer_a; - LdsTransfer lds_transfer_b; - Epilogue epilogue_c; - AccessOrder block_transfer_access_order_a; - AccessOrder block_transfer_access_order_b; - AccessOrder src_access_order_a; - AccessOrder src_access_order_b; + BlockTransfer block_transfer; + LdsTransfer lds_transfer; + AccessOrder block_transfer_access_order; + AccessOrder src_access_order; +}; + +struct TransferC +{ + ThreadCluster thread_cluster_dims; + Epilogue epilogue; +}; + +struct TransferABC +{ + TransferAB a; + TransferAB b; + TransferC c; }; // DL-specific descriptors @@ -172,9 +179,9 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; -struct BlockTransfer_ +struct Transfer_ { - BlockTransferABC block_transfer; + TransferABC transfer; }; struct ConvSpecialization_ @@ -205,15 +212,26 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; -struct DlBlockTransfer_ +struct DlBlockTransferAB { - DlBlockTransfer block_transfer_a; - DlBlockTransfer block_transfer_b; + DlBlockTransfer block_transfer; }; -struct DlEpilogue_ +struct DlBlockTransferC { - DlEpilogue epilogue_c; + DlEpilogue epilogue; +}; + +struct DlTransferABC +{ + DlBlockTransferAB a; + DlBlockTransferAB b; + DlBlockTransferC c; +}; + +struct DlTransfer_ +{ + DlTransferABC transfer; }; // Specialization wrapper for large tensor support @@ -255,12 +273,12 @@ struct ConvAlgorithmTemplate : Components... return result; } - template - constexpr auto with_block_transfer(const BT& bt) const + template + constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.block_transfer = bt; + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; return result; } @@ -313,21 +331,12 @@ struct ConvAlgorithmTemplate : Components... return result; } - template - constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const + template + constexpr auto with_dl_transfer(const T& t) const { - static_assert(std::is_base_of_v); - auto result = *this; - result.block_transfer_a = bta; - result.block_transfer_b = btb; - return result; - } - - constexpr auto with_dl_epilogue(const DlEpilogue& epi) const - { - static_assert(std::is_base_of_v); - auto result = *this; - result.epilogue_c = epi; + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; return result; } }; @@ -335,20 +344,19 @@ struct ConvAlgorithmTemplate : Components... // Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; + DlTransfer_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index c2f7039348..941fac5389 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -64,30 +64,39 @@ struct DefaultAlgorithm .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; - ckb::test::BlockTransferABC block_transfer{ - .block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8}, - .block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {.order = {0, 1, 2}}, - .block_transfer_access_order_b = {.order = {0, 1, 2}}, - .src_access_order_a = {.order = {0, 1, 2}}, - .src_access_order_b = {.order = {0, 1, 2}}}; + ckb::test::TransferABC transfer{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .block_transfer_access_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, + + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .block_transfer_access_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, + }; ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 7f2acce9c8..55d21e0fd2 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -25,109 +25,152 @@ constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}; +constexpr DlTransferABC DlFwdTransfer{.a = + { + .block_transfer = DlBlockTransferAB, + }, + .b = + { + .block_transfer = DlBlockTransferAB, + }, + .c = { + .epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}, + }}; -constexpr BlockTransferABC FwdBlockTransfer_4x64x1{ - .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x64x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; -constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{ - .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x64x1_fp8{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; -constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ - .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr TransferABC FwdTransfer_4x16x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, -constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ - .block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; + }, +}; + +constexpr TransferABC FwdTransfer_4x32x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, + .epilogue = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};