diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index cf31ee64c0..9b120bae9a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -281,7 +281,7 @@ struct LargeTensorAlgorithm : public FwdXdlAlgorithmBase }; template -struct BwdXdlAlgorithm { +struct BwdXdlAlgorithmBase { /* Shared checks c1..c8 for the backward XDL checkers; message() returns only the per-concept lines, and each derived checker prepends its own header. */ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) @@ -290,8 +290,6 @@ struct BwdXdlAlgorithm { CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -301,16 +299,13 @@ struct BwdXdlAlgorithm { static constexpr bool c6 = c_SpecifiesSourceAccessOrder; static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_TransposeTransferWellDefinedIfProvided; - static constexpr bool c10 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } static consteval auto message() -> std::string { - return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdXdl Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + @@ -318,49 +313,45 @@ struct BwdXdlAlgorithm { DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + 
DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); + } +}; + +template +struct BwdXdlAlgorithm : public BwdXdlAlgorithmBase{ /* Base checks plus SpecifiesTransposeTransfer (c9) and SpecifiesGenericInstance (c10). */ + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c9 = c_SpecifiesTransposeTransfer; + static constexpr bool c10 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c9 && c10 && BwdXdlAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdl Algorithm:\n") + + BwdXdlAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; template -struct BwdMultiDXdlAlgorithm { - CHECK_CONCEPT(T, ConvAlgorithmDescriptor) - CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) - CHECK_CONCEPT(T, SpecifiesLdsTransfer) - CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) - CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) - CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) - CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) +struct BwdMultiDXdlAlgorithm : public BwdXdlAlgorithmBase{ /* Base checks plus SpecifiesMultipleDSupport (c9). */ CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - static constexpr bool c1 = c_ConvAlgorithmDescriptor; - static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransfer4D; - static constexpr bool c4 = c_SpecifiesLdsTransfer; - static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; - static constexpr bool c6 = c_SpecifiesSourceAccessOrder; - static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; - static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesMultipleDSupport; static consteval bool is_valid() { 
- return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + return c9 && BwdXdlAlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward XDL Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdXdl Algorithm:\n") + - DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + - DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + - DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + - DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + - DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + BwdXdlAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); } }; @@ -448,7 +439,7 @@ struct BwdTwoStageXdlAlgorithm : public BwdXdlV3AlgorithmBase{ }; template -struct BwdWmmaAlgorithm { +struct BwdWmmaAlgorithmBase { /* Shared checks c1..c8 for the backward WMMA checkers; message() returns only the per-concept lines. */ CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) CHECK_CONCEPT(T, SpecifiesBlockTransfer) @@ -457,10 +448,6 @@ struct BwdWmmaAlgorithm { CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) - CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) - CHECK_CONCEPT(T, SpecifiesLoopScheduler) - CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) - CHECK_CONCEPT(T, SpecifiesGenericInstance) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -470,18 +457,13 @@ struct BwdWmmaAlgorithm { static constexpr bool c6 = c_SpecifiesSourceAccessOrder; static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; - static constexpr bool c9 = c_SpecifiesNumPrefetchStages; - static constexpr bool c10 = c_SpecifiesLoopScheduler; - static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline; - static constexpr bool c12 = c_SpecifiesGenericInstance; static consteval bool 
is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10 && c11 && c12; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8; } static consteval auto message() -> std::string { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdWmma Algorithm:\n") + + return DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + @@ -489,7 +471,30 @@ struct BwdWmmaAlgorithm { DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + - DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization); + } +}; + +template +struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { /* Base checks plus prefetch stages, loop scheduler, gridwise pipeline and generic instance (c9..c12). */ + CHECK_CONCEPT(T, SpecifiesNumPrefetchStages) + CHECK_CONCEPT(T, SpecifiesLoopScheduler) + CHECK_CONCEPT(T, SpecifiedGridwiseGemmPipeline) + CHECK_CONCEPT(T, SpecifiesGenericInstance) + + static constexpr bool c9 = c_SpecifiesNumPrefetchStages; + static constexpr bool c10 = c_SpecifiesLoopScheduler; + static constexpr bool c11 = c_SpecifiedGridwiseGemmPipeline; + static constexpr bool c12 = c_SpecifiesGenericInstance; + + static consteval bool is_valid() { + return c9 && c10 && c11 && c12 && BwdWmmaAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + DIAGNOSTIC_LINE(SpecifiesNumPrefetchStages) + DIAGNOSTIC_LINE(SpecifiesLoopScheduler) + DIAGNOSTIC_LINE(SpecifiedGridwiseGemmPipeline) + @@ -497,6 +502,27 @@ struct BwdWmmaAlgorithm { } }; +template +struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase { /* Base checks plus SpecifiesBlockGemm (c9) and SpecifiesMultipleDSupport (c10); there is no c11 or c12 here. */ + CHECK_CONCEPT(T, SpecifiesBlockGemm) + CHECK_CONCEPT(T, SpecifiesMultipleDSupport) + + static constexpr bool c9 = c_SpecifiesBlockGemm; + 
static constexpr bool c10 = c_SpecifiesMultipleDSupport; + + static consteval bool is_valid() { + return c9 && c10 && BwdWmmaAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdMultiDWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); + } +}; + template struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -508,7 +534,7 @@ struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, TransposeTransferWellDefinedIfProvided) + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -519,7 +545,7 @@ struct BwdWmmaV3AlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_TransposeTransferWellDefinedIfProvided; + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static consteval bool is_valid() { return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; @@ -536,7 +562,7 @@ struct BwdWmmaV3AlgorithmBase { DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(TransposeTransferWellDefinedIfProvided); + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); } }; @@ -670,6 +696,12 @@ consteval int count_matches_bwd_wmma() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } +template +consteval int count_matches_bwd_multi_d_wmma() { /* Number of BwdMultiDWmmaAlgorithm checks satisfied; diagnose_bwd_weight_algorithm_signature compares it with the other WMMA counts to pick the closest match. */ + using Alg = 
BwdMultiDWmmaAlgorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10; /* BwdMultiDWmmaAlgorithm defines c1..c8 (via BwdWmmaAlgorithmBase) plus c9 and c10 only, so summing c11 or c12 would not compile. */ +} + template consteval int count_matches_bwd_wmma_v3() { using Alg = BwdWmmaV3Algorithm; @@ -785,6 +817,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() constexpr int wmma_v3_matches = count_matches_bwd_wmma_v3(); constexpr int two_stage_wmma_v3_matches = count_matches_bwd_two_stage_wmma_v3(); constexpr int wmma_matches = count_matches_bwd_wmma(); + constexpr int multi_d_wmma_matches = count_matches_bwd_multi_d_wmma(); // Check whether we have XDL or WMMA algorithm if constexpr (SpecifiesGridwiseBwdXdlGemm) @@ -818,7 +851,9 @@ consteval void diagnose_bwd_weight_algorithm_signature() else if constexpr (SpecifiesGridwiseWmmaGemm) { constexpr int max_1 = wmma_v3_matches > two_stage_wmma_v3_matches ? wmma_v3_matches : two_stage_wmma_v3_matches; - constexpr int max_matches = max_1 > wmma_matches ? max_1 : wmma_matches; + constexpr int max_2 = max_1 > wmma_matches ? max_1 : wmma_matches; + constexpr int max_matches = multi_d_wmma_matches > max_2 ? 
multi_d_wmma_matches : max_2; + if constexpr (max_matches == wmma_v3_matches) { using Alg = BwdWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); @@ -831,6 +866,10 @@ consteval void diagnose_bwd_weight_algorithm_signature() using Alg = BwdWmmaAlgorithm; static_assert(Alg::is_valid(), Alg::message()); } + else if constexpr (max_matches == multi_d_wmma_matches) { /* Taken only when none of the earlier WMMA checkers in this chain equals the maximum. */ + using Alg = BwdMultiDWmmaAlgorithm; + static_assert(Alg::is_valid(), Alg::message()); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 4e1a4766fd..533e775272 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -186,6 +186,11 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightWmmaFactory::Instance{}; } + else if constexpr (BwdMultiDWmmaAlgorithm::is_valid()) + { /* The descriptor passes the BwdMultiDWmmaAlgorithm checks, but this branch produces no instance; it stops compilation with an explicit message rather than falling through to the generic diagnostic. */ + static_assert(false, + "Backward weight convolution with multi-D WMMA algorithm is not yet supported."); + } else { diagnose_bwd_weight_algorithm_signature(); diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 179b73b784..b2f5970d2e 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -178,6 +178,7 @@ if (CK_USE_WMMA) conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp ) endif() diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp new file mode 100644 index 0000000000..e050ffad4e --- /dev/null +++ 
b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_MultiD_Wmma_Shuffle_GNHWC, Create) +{ /* Prints to_string(ALGORITHM) and hands that text to run_test together with the other expected strings. */ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index f53bdaa6ab..b126b5af8d 100644 --- 
a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -547,7 +547,7 @@ struct ConvAlgorithmTemplate : Components... } }; -// Algorithm types +// Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; @@ -568,6 +568,7 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_, LargeTensorSpecialization_>; +// CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; @@ -607,5 +609,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index 56c9e1755f..ad887b2491 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -443,6 +443,16 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle t) +{ /* Builds a comma-separated string from the to_string output of three static_cast conversions of t. */ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)