From 75710202ab3466a6eb8d63e2c6903e35ecce484f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 31 Dec 2025 04:32:28 -0500 Subject: [PATCH] Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. --- .../builder/conv_algorithm_concepts.hpp | 17 +++ .../builder/conv_algorithm_diagnostics.hpp | 30 +++++ .../builder/factory/conv_algorithms.hpp | 71 +++++++++++- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 106 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 24 +--- .../builder/include/ck_tile/builder/types.hpp | 3 +- experimental/builder/test/CMakeLists.txt | 1 + ...conv_bwd_weight_two_stage_xdl_cshuffle.cpp | 44 ++++++++ .../test_ckb_conv_bwd_weight_xdl_cshuffle.cpp | 6 +- .../test/impl/conv_algorithm_types.hpp | 25 ++++- .../test/utils/conv_algorithm_type_utils.hpp | 10 ++ 11 files changed, 310 insertions(+), 27 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp create mode 100644 experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 447bbdad5e..d554f92422 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -321,12 +321,29 @@ concept SpecifiesLargeTensorSupport = requires { requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; }; +template +concept SpecifiesTwoStageSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; +}; + +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +}; + template concept SpecifiesTransposeTransfer = requires { { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; }; + +template +concept SpecifiesGemmBatchOptions = requires { + { T::num_conv_groups_to_merge } -> SizeType; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 6613d2d736..7497355224 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -712,6 +712,36 @@ consteval auto detailed_diagnostic_SpecifiesLargeTensorSupport() -> std::string return msg; } +template +consteval auto detailed_diagnostic_SpecifiesTwoStageSupport() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + using SpecType = decltype(T::specialization); + constexpr bool convertible = std::convertible_to; + msg += " → T::specialization: " + std::string(CHECK_MARK(convertible)) + + (convertible ? "" : std::string(detail::get_type_info())) + "\n"; + + if constexpr (convertible) { + constexpr bool is_two_stage = (T::specialization == ConvAlgorithmSpecialization::TWO_STAGE); + msg += " → specialization == TWO_STAGE: " + std::string(CHECK_MARK(is_two_stage)) + "\n"; + } + } else { + msg += " → T::specialization: [✗] (missing member)\n"; + } + + return msg; +} + +template +consteval auto detailed_diagnostic_SpecifiesGenericInstance() -> std::string { + std::string msg; + if constexpr (requires { T::specialization; }) { + msg += " → T::specialization: [✗] (member should NOT exist for generic instance)\n"; + msg += " → This concept requires the absence of the specialization member\n"; + } + return msg; +} + template consteval auto detailed_diagnostic_SpecifiesTransposeTransfer() -> std::string { std::string msg; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index bf7f0248fd..312e746f8a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -290,6 +290,7 @@ struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -300,9 +301,10 @@ struct BwdXdlV3Algorithm { static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; } static consteval auto message() -> std::string { @@ -316,7 +318,58 @@ struct BwdXdlV3Algorithm { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm); + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesGenericInstance); + } +}; + +template +struct BwdTwoStageXdlAlgorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + CHECK_CONCEPT(T, SpecifiesTwoStageSupport) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransfer; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesBlockGemm; + static constexpr bool c10 = c_SpecifiesTransposeTransfer; + static constexpr bool c11 = c_SpecifiesGemmBatchOptions; + static constexpr bool c12 = c_SpecifiesTwoStageSupport; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && & c10 && c11 && c12; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward two stage XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdlV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } }; @@ -356,6 +409,12 @@ consteval int count_matches_bwd_xdl_v3() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } +template +consteval int count_matches_bwd_two_stage_xdl() { + using Alg = BwdTwoStageXdlAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; +} + template consteval int count_matches_large_tensor() { using Alg = LargeTensorAlgorithm; @@ -417,8 +476,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() { constexpr int xdl_matches = count_matches_bwd_xdl(); constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + constexpr int two_stage_xdl_matches = count_matches_bwd_two_stage_xdl(); - constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max1 = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + constexpr int max_matches = max1 > two_stage_xdl_matches ? max1 : two_stage_xdl_matches; if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; @@ -428,6 +489,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdXdlV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == two_stage_xdl_matches) { + using Alg = BwdTwoStageXdlAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp new file mode 100644 index 0000000000..b9852127e8 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 01c0fb9c56..1812f1a0ff 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -49,12 +49,6 @@ #pragma once -// Disable pragma message warnings for factory selection diagnostics -#ifdef __clang__ - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-W#pragma-messages" -#endif - #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" @@ -71,6 +65,7 @@ #include "ck_tile/builder/factory/conv_tile_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -103,34 +98,28 @@ constexpr auto make_conv_instance() // CK Tile supports common factory for each direction if constexpr(TileAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvTileFactory...") return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { if constexpr(FwdXdlV3Algorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdXdlV3Factory...") return typename ConvFwdXdlV3Factory::Instance{}; } else if constexpr(FwdXdlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdXdlFactory...") return typename ConvFwdXdlFactory::Instance{}; } else if constexpr(FwdWmmaAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdWmmaFactory...") return typename ConvFwdWmmaFactory::Instance{}; } else if constexpr(FwdDlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdDlFactory...") return typename ConvFwdDlFactory::Instance{}; } else if constexpr(LargeTensorAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...") return typename ConvFwdLargeTensorFactory::Instance{}; } else @@ -148,14 +137,16 @@ constexpr auto make_conv_instance() { if constexpr (BwdXdlAlgorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...") return typename ConvBwdWeightXdlFactory::Instance{}; } else if constexpr (BwdXdlV3Algorithm::is_valid()) { - #pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...") return typename ConvBwdWeightXdlV3Factory::Instance{}; } + else if constexpr (BwdTwoStageXdlAlgorithm::is_valid()) + { + return typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); @@ -171,8 +162,3 @@ constexpr auto make_conv_instance() } } // namespace ck_tile::builder::factory - -// Re-enable pragma message warnings -#ifdef __clang__ - #pragma clang diagnostic pop -#endif diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index ade9484640..c44f3368ae 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -232,7 +232,8 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { - LARGE_TENSOR + LARGE_TENSOR, + TWO_STAGE }; // toString methods for enum classes diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 6a54edab9a..a3d08e82ed 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -151,6 +151,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp new file mode 100644 index 0000000000..9a8b9573fa --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "Intrawave,v2", // pipeline versions + "bf16,bf16,2,4>"}); // compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp index ad11eba693..892f1d35ef 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp @@ -23,7 +23,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh .with_thread_block(cku::ThreadBlock_256_128x128x8) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::BwdTransfer_4x64x1) - .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_transpose_params(2, 2); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; @@ -35,5 +36,6 @@ TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create) expected_transfer_parameters, "Default", "GNHWC,GKYXC,GNHWK", - "PassThrough,PassThrough,PassThrough"}); + "PassThrough,PassThrough,PassThrough", + "fp16,fp16,2,2>"}); // check compute types and transpose params } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b045d185e2..d003440935 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,11 @@ struct TransposeParams_ size_t max_transpose_transfer_dst_scalar_per_vector{1}; }; +struct GemmBatchOptions_ +{ + size_t num_conv_groups_to_merge{1}; +}; + struct BlockGemm_ { BlockGemm block_gemm; @@ -280,6 +285,11 @@ struct DlTransfer_ DlTransferABC transfer; }; +struct TwoStageSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::TWO_STAGE; +}; + // Specialization wrapper for large tensor support template struct LargeTensorWrapper @@ -433,8 +443,8 @@ struct ConvAlgorithmTemplate : Components... return result; } - constexpr auto with_transpose_params(bool max_src_scalar_per_vector, - bool max_dst_scalar_per_vector) const + constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, + size_t max_dst_scalar_per_vector) const { static_assert(std::is_base_of_v); auto result = *this; @@ -443,6 +453,14 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_conv_groups_to_merge = num_groups_to_merge; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { @@ -555,6 +573,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, TransposeParams_, GemmBatchOptions_, TwoStageSpecialization_>; + using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index cf13f39391..8f530600ac 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -397,4 +397,14 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + } // namespace ck_tile::builder::test