From 9e0594f272bdaaf39599ffeaf16dfd92e377569d Mon Sep 17 00:00:00 2001 From: Kevin Abraham Date: Fri, 30 Jan 2026 21:00:04 +0000 Subject: [PATCH] first instance of bwd data factory --- .../builder/conv_algorithm_concepts.hpp | 19 +++ .../builder/factory/conv_algorithms.hpp | 12 +- .../conv_bwd_data_multi_d_xdl_factory.hpp | 113 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 18 ++- .../factory/helpers/ck/conv_tuning_params.hpp | 21 ++++ experimental/builder/test/CMakeLists.txt | 1 + ...ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp | 44 +++++++ .../test/impl/conv_algorithm_types.hpp | 55 +++++++++ .../test/utils/ckb_conv_test_configs.hpp | 20 ++++ .../test/utils/conv_algorithm_type_utils.hpp | 34 ++++++ 10 files changed, 329 insertions(+), 8 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 29a04d9b6c..bb87e50943 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -187,6 +187,14 @@ concept GridwiseBwdXdlGemmDescriptor = requires(T t) { { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept GridwiseBwdDataXdlGemmDescriptor = requires(T t) { + { t.ak1 } -> SizeType; + { t.bk1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { @@ -199,6 +207,12 @@ concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdDataXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseBwdDataXdlGemmDescriptor; +}; + // Concept to check if a struct specifies gridwise WMMA GEMM info. template concept SpecifiesGridwiseWmmaGemm = requires(T t) { @@ -292,6 +306,11 @@ concept SpecifiesBwdWeightConvSpecialization = requires { { T::bwd_weight_specialization } -> std::convertible_to; }; +template +concept SpecifiesBwdDataConvSpecialization = requires { + { T::bwd_data_specialization } -> std::convertible_to; +}; + template concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index c508126adb..d18650412d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -29,23 +29,27 @@ concept FwdXdlAlgorithmBase = template concept BwdXdlAlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; + (SpecifiesGridwiseBwdXdlGemm || SpecifiesGridwiseBwdDataXdlGemm) && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization); template concept BwdXdlV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && + (SpecifiesGridwiseBwdXdlGemm || SpecifiesGridwiseBwdDataXdlGemm) && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization) && SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; template concept BwdWmmaAlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; + SpecifiesGridwiseWmmaGemm && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization); template concept BwdWmmaV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && - SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesGridwiseWmmaGemm && + (SpecifiesBwdWeightConvSpecialization || SpecifiesBwdDataConvSpecialization) && SpecifiesBlockGemm; // Reference algorithm concept diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp new file mode 100644 index 0000000000..1b21b92341 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp @@ -0,0 +1,113 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_V1 instance +// of a grouped bwd Data convolution kernel. +template + requires ConvDirectionIsBackwardData +struct ConvBwdDataMultiDXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdDataConvSpecialization(); + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The backward convolution kernel class instance. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + SPATIAL_DIM, + typename Layouts::OutLayout, + typename Layouts::WeiLayout, + typename Layouts::DsLayout, + typename Layouts::InLayout, + typename Types::OutDataType, + typename Types::WeiDataType, + typename Types::AccDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::InDataType, + typename Ops::OutElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::InElementwiseOp, + BWD_CONV_SPECIALIZATION, + ALGORITHM.DoPadGemmM, + ALGORITHM.DoPadGemmN, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + LOOP_SCHEDULER, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index e235db4bb0..de289f6d91 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -77,6 +77,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_data_multi_d_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -151,10 +152,19 @@ constexpr auto make_conv_instance() // Backward data direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardData) { - static_assert(false, - "Backward data convolution: Only reference and tile algorithms supported " - "currently. " - "Optimized kernels (XDL, WMMA, etc.) not yet implemented."); + if constexpr(BwdMultiDXdlAlgorithm) + { + return typename ConvBwdDataMultiDXdlFactory::Instance{}; + } + else + { + static_assert( + false, + "No suitable backward data convolution kernel factory found for the provided " + "ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, " + "WMMA, DL (NHWC layout), or Large Tensor variant."); + } } // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 3b1ea65695..16e1b4e51d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -5,6 +5,7 @@ #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -180,4 +181,24 @@ SetBwdWeightConvSpecialization() } } +template +consteval ck::tensor_operation::device::ConvolutionBackwardDataSpecialization +SetBwdDataConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_data_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: + throw "FILTER_1x1_PAD0 is not supported for backward data convolution."; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: + throw "FILTER ODD_C is not supported for backward data convolution."; + case ConvSpecialization::FILTER_3x3: + throw "FILTER_3x3 is not supported for backward data convolution."; + default: throw "Unsupported ConvSpecialization"; + } +} + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 73a682f10c..6b63d221d3 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -178,6 +178,7 @@ set(BWD_WEIGHT_TESTS conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_dl.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp ) if (CK_USE_WMMA) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp new file mode 100644 index 0000000000..62682f828c --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_data_multi_d_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "gmock/gmock.h" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_DATA, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdDataGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_bwd_data_specialization(ckb::ConvSpecialization::DEFAULT) + .with_gemm_pad_params(0, 0) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdData_2DFp16_MultiD_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index f5b9bdc3b5..02fbdee164 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" namespace ck_tile::builder::test { @@ -54,6 +55,13 @@ struct GridwiseBwdXdlGemm }; static_assert(ckb::GridwiseBwdXdlGemmDescriptor); +struct GridwiseBwdDataXdlGemm +{ + size_t ak1 = 0; + size_t bk1 = 0; + XdlParams xdl_params; +}; + // Describe gridwise WMMA GEMM parameters. struct GridwiseWmmaGemm { @@ -209,6 +217,11 @@ struct BwdXdlGemm_ GridwiseBwdXdlGemm gridwise_gemm; }; +struct BwdDataXdlGemm_ +{ + GridwiseBwdDataXdlGemm gridwise_gemm; +}; + struct WmmaGemm_ { GridwiseWmmaGemm gridwise_gemm; @@ -231,12 +244,23 @@ struct ConvSpecializationBwdWeight_ ConvSpecialization bwd_weight_specialization; }; +struct ConvSpecializationBwdData_ +{ + ConvSpecialization bwd_data_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; PipelineScheduler loop_scheduler; }; +struct GemmPad_ +{ + size_t DoPadGemmM; + size_t DoPadGemmN; +}; + struct TransposeParams_ { size_t max_transpose_transfer_src_scalar_per_vector{1}; @@ -394,6 +418,10 @@ struct ConvAlgorithmTemplate : Components... { result.gridwise_gemm = gemm; } + else if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } else if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; @@ -433,6 +461,14 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_bwd_data_specialization(ConvSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_data_specialization = bwd_spec; + return result; + } + constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const { static_assert(std::is_base_of_v); @@ -452,6 +488,15 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_gemm_pad_params(size_t doPadGemmN_, size_t doPadGemmM_) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.DoPadGemmN = doPadGemmN_; + result.DoPadGemmM = doPadGemmM_; + return result; + } + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const { static_assert(std::is_base_of_v); @@ -683,4 +728,14 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = BlockGemm_, MultipleDSpecialization_>; +// Bwd Data algorithm types +using ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdData_, + MultipleDSpecialization_, + Prefetch_, + TransposeParams_, + GemmPad_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index e48f1dd6ba..4c7f0acdef 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -249,6 +249,26 @@ constexpr Transfer<> Transfer_4x32x1{ }, }; +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; + +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; + +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; + +constexpr GridwiseBwdDataXdlGemm BwdDataGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; + constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 178029e338..7e40940295 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -85,6 +85,15 @@ inline std::string to_string(ThreadBlock t) return oss.str(); } +template <> +inline std::string to_string(GridwiseBwdDataXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl + << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + return oss.str(); +} + template <> inline std::string to_string(GridwiseBwdXdlGemm t) { @@ -283,6 +292,12 @@ inline std::string to_string(BwdXdlGemm_ t) return to_string(t.gridwise_gemm); } +template <> +inline std::string to_string(BwdDataXdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + template <> inline std::string to_string(WmmaGemm_ t) { @@ -311,6 +326,14 @@ inline std::string to_string(ConvSpecializationBwd return oss.str(); } +template <> +inline std::string to_string(ConvSpecializationBwdData_ t) +{ + std::ostringstream oss; + oss << to_string(t.bwd_data_specialization); + return oss.str(); +} + template <> inline std::string to_string(Prefetch_ t) { @@ -495,4 +518,15 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test