diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 9b120bae9a..800403b16d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -502,27 +502,6 @@ struct BwdWmmaAlgorithm : public BwdWmmaAlgorithmBase { } }; -template -struct BwdMultiDWmmaAlgorithm : public BwdWmmaAlgorithmBase { - CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesMultipleDSupport) - - static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesMultipleDSupport; - - static consteval bool is_valid() { - return c9 && c10 && BwdWmmaAlgorithmBase::is_valid(); - } - - static consteval auto message() -> std::string { - return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" - "Concepts for BwdMultiDWmma Algorithm:\n") + - BwdWmmaAlgorithmBase::message() + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); - } -}; - template struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) @@ -534,7 +513,6 @@ struct BwdWmmaV3AlgorithmBase { CHECK_CONCEPT(T, SpecifiesGridwiseWmmaGemm) CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) CHECK_CONCEPT(T, SpecifiesBlockGemm) - CHECK_CONCEPT(T, SpecifiesTransposeTransfer) static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; @@ -545,10 +523,9 @@ struct BwdWmmaV3AlgorithmBase { static constexpr bool c7 = c_SpecifiesGridwiseWmmaGemm; static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; static constexpr bool c9 = c_SpecifiesBlockGemm; - static constexpr bool c10 = c_SpecifiesTransposeTransfer; static consteval bool is_valid() { - return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9 && c10; + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; } static consteval auto message() -> std::string { @@ -561,26 +538,46 @@ struct BwdWmmaV3AlgorithmBase { DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + DIAGNOSTIC_LINE(SpecifiesGridwiseWmmaGemm) + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + - DIAGNOSTIC_LINE(SpecifiesBlockGemm) + - DIAGNOSTIC_LINE(SpecifiesTransposeTransfer); + DIAGNOSTIC_LINE(SpecifiesBlockGemm); + } +}; + +template +struct BwdMultiDWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesMultipleDSupport) + + static constexpr bool c10 = c_SpecifiesMultipleDSupport; + + static consteval bool is_valid() { + return c10 && BwdWmmaAlgorithmBase::is_valid(); + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward WMMA Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdMultiDWmma Algorithm:\n") + + BwdWmmaAlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesMultipleDSupport); } }; template struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesGenericInstance) + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static constexpr bool c11 = c_SpecifiesGenericInstance; static consteval bool is_valid() { - return c11 && BwdWmmaV3AlgorithmBase::is_valid(); + return c10 && c11 && BwdWmmaV3AlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward WMMA V3 Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdWmmaV3 Algorithm:\n") + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGenericInstance); } }; @@ -588,20 +585,23 @@ struct BwdWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase template struct BwdTwoStageWmmaV3Algorithm : public BwdWmmaV3AlgorithmBase { + CHECK_CONCEPT(T, SpecifiesTransposeTransfer) CHECK_CONCEPT(T, SpecifiesTwoStageSupport) CHECK_CONCEPT(T, SpecifiesGemmBatchOptions) + static constexpr bool c10 = c_SpecifiesTransposeTransfer; static constexpr bool c11 = c_SpecifiesTwoStageSupport; static constexpr bool c12 = c_SpecifiesGemmBatchOptions; static consteval bool is_valid() { - return c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); + return c10 && c11 && c12 && BwdWmmaV3AlgorithmBase::is_valid(); } static consteval auto message() -> std::string { return std::string("\n=== Backward Two Stage WMMA V3 Algorithm Diagnostic (closest match) ===\n" "Concepts for BwdTwoStageWmmaV3 Algorithm:\n") + BwdWmmaV3AlgorithmBase::message() + + DIAGNOSTIC_LINE(SpecifiesTransposeTransfer) + DIAGNOSTIC_LINE(SpecifiesGemmBatchOptions) + DIAGNOSTIC_LINE(SpecifiesTwoStageSupport); } @@ -698,7 +698,7 @@ consteval int count_matches_bwd_wmma() { template consteval int count_matches_bwd_multi_d_wmma() { - using Alg = BwdMultiDWmmaAlgorithm; + using Alg = BwdMultiDWmmaV3Algorithm; return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9 + Alg::c10 + Alg::c11 + Alg::c12; } @@ -867,7 +867,7 @@ consteval void diagnose_bwd_weight_algorithm_signature() static_assert(Alg::is_valid(), Alg::message()); } else if constexpr (max_matches == multi_d_wmma_matches) { - using Alg = BwdMultiDWmmaAlgorithm; + using Alg = BwdMultiDWmmaV3Algorithm; static_assert(Alg::is_valid(), Alg::message()); } } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp new file mode 100644 index 0000000000..2ba1300537 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightMultiDWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits4D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits4D, "Invalid A source access order"); + static_assert(AccessOrderLimits4D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 533e775272..9bca017735 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -76,6 +76,7 @@ #include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -186,10 +187,9 @@ constexpr auto make_conv_instance() { return typename ConvBwdWeightWmmaFactory::Instance{}; } - else if constexpr (BwdMultiDWmmaAlgorithm::is_valid()) + else if constexpr (BwdMultiDWmmaV3Algorithm::is_valid()) { - static_assert(false, - "Backward weight convolution with multi-D WMMA algorithm is not yet supported."); + return typename ConvBwdWeightMultiDWmmaV3Factory::Instance{}; } else { diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp index e050ffad4e..e2bcd4a926 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle.cpp @@ -20,9 +20,9 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle{} - .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b126b5af8d..eb03ecfab2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -610,7 +610,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_, MultipleDSpecialization_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ad887b2491..ee4bd1f597 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -449,7 +449,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3b8be8bf8..dbce8e8ccf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -50,7 +50,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -861,7 +861,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -875,7 +875,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t, @@ -914,7 +914,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, remove_reference_t, remove_reference_t,